From ec893798b7a2a803466cc8f063051499ec3d96f7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Sep 2023 19:04:36 +0300 Subject: [PATCH] llama : custom attention mask + parallel decoding + no context swaps (#3228) * tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren --- .gitignore | 2 + Makefile | 8 +- common/common.cpp | 43 +- common/common.h | 8 +- examples/CMakeLists.txt | 2 + examples/baby-llama/baby-llama.cpp | 37 +- examples/batched/CMakeLists.txt | 5 + examples/batched/README.md | 44 + examples/batched/batched.cpp | 246 +++++ examples/beam-search/beam-search.cpp | 5 +- examples/embd-input/embd-input-lib.cpp | 11 +- examples/embedding/embedding.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 8 +- examples/main/main.cpp | 36 +- examples/parallel/CMakeLists.txt | 8 + examples/parallel/README.md | 3 + examples/parallel/parallel.cpp | 380 +++++++ examples/perplexity/perplexity.cpp | 51 +- examples/save-load-state/save-load-state.cpp | 16 +- examples/server/server.cpp | 35 +- examples/simple/README.md | 21 + examples/simple/simple.cpp | 136 ++- examples/speculative/speculative.cpp | 19 +- .../train-text-from-scratch.cpp | 16 +- ggml-cuda.cu | 147 ++- ggml-cuda.h | 1 + ggml-metal.m | 131 ++- ggml-metal.metal | 159 ++- ggml.c | 170 ++- ggml.h | 45 +- llama.cpp | 990 +++++++++++++----- llama.h | 325 ++++-- tests/CMakeLists.txt | 2 + tests/test-grad0.cpp | 14 +- tests/test-rope.cpp | 221 ++++ 35 files changed, 2687 insertions(+), 660 deletions(-) create mode 100644 examples/batched/CMakeLists.txt create mode 100644 examples/batched/README.md create mode 100644 examples/batched/batched.cpp create mode 100644 examples/parallel/CMakeLists.txt create mode 100644 examples/parallel/README.md create mode 100644 examples/parallel/parallel.cpp create mode 100644 examples/simple/README.md create mode 100644 tests/test-rope.cpp diff --git a/.gitignore b/.gitignore index b862a0415f279..b54723a15052d 100644 --- a/.gitignore +++ b/.gitignore @@ -51,7 +51,9 @@ models-mnt /save-load-state /server /simple +/batched /speculative +/parallel /train-text-from-scratch /vdot build-info.h diff --git a/Makefile b/Makefile index f170f22939942..c7f6a808ed379 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative tests/test-c.o +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o # Binaries only useful for tests TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama @@ -519,6 +519,9 @@ main: examples/main/main.cpp build-info.h ggml. simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @@ -565,6 +568,9 @@ beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o co speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + ifdef LLAMA_METAL metal: examples/metal/metal.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) diff --git a/common/common.cpp b/common/common.cpp index 2597ba06aee16..7c3e11875cb0b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_chunks = std::stoi(argv[i]); + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + } else if (arg == "-ns" || arg == "--sequences") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_sequences = std::stoi(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.multiline_input = true; } else if (arg == "--simple-io") { params.simple_io = true; + } else if (arg == "-cb" || arg == "--cont-batching") { + params.cont_batching = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -436,8 +450,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_mmap = false; } else if (arg == "--numa") { params.numa = true; - } else if (arg == "--export") { - params.export_cgraph = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-r" || arg == "--reverse-prompt") { @@ -456,8 +468,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { if (params.logdir.back() != DIRECTORY_SEPARATOR) { params.logdir += DIRECTORY_SEPARATOR; } - } else if (arg == "--perplexity") { - params.perplexity = true; + } else if (arg == "--perplexity" || arg == "--all-logits") { + params.logits_all = true; } else if (arg == "--ppl-stride") { if (++i >= argc) { invalid_param = true; @@ -655,12 +667,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); - printf(" --perplexity compute perplexity over each ctx window of the prompt\n"); + printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); + printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); + printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); + printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -685,7 +700,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" Not recommended since this is both slower and uses more VRAM.\n"); #endif // GGML_USE_CUBLAS #endif - printf(" --export export the computation graph to 'llama.ggml'\n"); printf(" --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); @@ -738,7 +752,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; + lparams.logits_all = params.logits_all; lparams.embedding = params.embedding; lparams.rope_freq_base = params.rope_freq_base; lparams.rope_freq_scale = params.rope_freq_scale; @@ -782,8 +796,9 @@ std::tuple llama_init_from_gpt_par { LOG("warming up the model with an empty run\n"); - const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); + std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); + llama_kv_cache_tokens_rm(lctx, -1, -1); llama_reset_timings(lctx); } @@ -890,7 +905,7 @@ llama_token llama_sample_token( llama_token id = 0; - float * logits = llama_get_logits(ctx) + idx * n_vocab; + float * logits = llama_get_logits_ith(ctx, idx); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -941,11 +956,11 @@ llama_token llama_sample_token( if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling @@ -953,7 +968,7 @@ llama_token llama_sample_token( llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); llama_sample_typical (ctx, &cur_p, typical_p, 1); llama_sample_top_p (ctx, &cur_p, top_p, 1); - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); { const int n_top = 10; @@ -1182,7 +1197,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); @@ -1256,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", params.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); diff --git a/common/common.h b/common/common.h index 2761503b826c7..16e30b2f5ccbf 100644 --- a/common/common.h +++ b/common/common.h @@ -42,6 +42,8 @@ struct gpt_params { int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 16; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors @@ -107,16 +109,16 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool cont_batching = false; // insert new sequences for decoding on-the-fly bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token - bool perplexity = false; // compute perplexity over the prompt + bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool numa = false; // attempt optimizations that help on some NUMA systems - bool export_cgraph = false; // export the computation graph bool verbose_prompt = false; // print prompt tokens before generation }; @@ -181,7 +183,7 @@ std::string llama_detokenize_bpe( // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL // - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits(ctx) + idx * n_vocab +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 884c4276422eb..129cc01163957 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -23,7 +23,9 @@ else() add_subdirectory(train-text-from-scratch) add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(simple) + add_subdirectory(batched) add_subdirectory(speculative) + add_subdirectory(parallel) add_subdirectory(embd-input) add_subdirectory(llama-bench) add_subdirectory(beam-search) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index ed61125eaa4da..b02a80863d010 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -554,6 +554,14 @@ static struct ggml_tensor * forward( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -581,8 +589,8 @@ static struct ggml_tensor * forward( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, 1] // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); // store key and value to memory { @@ -808,9 +816,18 @@ static struct ggml_tensor * forward_batch( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N*n_batch,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); assert_shape_2d(inpL, n_embd, N*n_batch); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -838,8 +855,8 @@ static struct ggml_tensor * forward_batch( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, n_batch] // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); @@ -1097,6 +1114,14 @@ static struct ggml_tensor * forward_lora( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -1130,7 +1155,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wqb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, @@ -1139,7 +1164,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wkb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); // store key and value to memory { diff --git a/examples/batched/CMakeLists.txt b/examples/batched/CMakeLists.txt new file mode 100644 index 0000000000000..6aa178d4d5911 --- /dev/null +++ b/examples/batched/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET batched) +add_executable(${TARGET} batched.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/batched/README.md b/examples/batched/README.md new file mode 100644 index 0000000000000..5d730331769fb --- /dev/null +++ b/examples/batched/README.md @@ -0,0 +1,44 @@ +# llama.cpp/example/batched + +The example demonstrates batched generation from a given prompt + +```bash +./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4 + +... + +main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113 + + Hello my name is + +main: generating 4 sequences ... + +main: stream 0 finished +main: stream 1 finished +main: stream 2 finished +main: stream 3 finished + +sequence 0: + +Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b + +sequence 1: + +Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between + +sequence 2: + +Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am + +sequence 3: + +Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and + +main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s + +llama_print_timings: load time = 587.00 ms +llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second) +llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second) +llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) +llama_print_timings: total time = 4156.04 ms +``` diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp new file mode 100644 index 0000000000000..4dd1d553d1c18 --- /dev/null +++ b/examples/batched/batched.cpp @@ -0,0 +1,246 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]); + return 1 ; + } + + int n_parallel = 1; + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + params.prompt = argv[2]; + } + + if (argc >= 4) { + n_parallel = std::atoi(argv[3]); + } + + if (params.prompt.empty()) { + params.prompt = "Hello my name is"; + } + + // total length of the sequences including the prompt + const int n_len = 32; + + // init LLM + + llama_backend_init(params.numa); + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = 1234; + ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301) + ctx_params.n_batch = std::max(n_len, n_parallel); + // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // tokenize the prompt + + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + const int n_ctx = llama_n_ctx(ctx); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); + + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); + LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__); + return 1; + } + + // print the prompt token-by-token + + fprintf(stderr, "\n"); + + for (auto id : tokens_list) { + fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + // create a llama_batch with size 512 + // we use this object to submit token data for decoding + + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0); + + // evaluate the initial prompt + batch.n_tokens = tokens_list.size(); + + for (int32_t i = 0; i < batch.n_tokens; i++) { + batch.token[i] = tokens_list[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; + } + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch, params.n_threads) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + // assign the system KV cache to all parallel sequences + // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them + for (int32_t i = 1; i < n_parallel; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens); + } + + if (n_parallel > 1) { + LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); + } + + // main loop + + // we will store the parallel decoded sequences in this vector + std::vector streams(n_parallel); + + // remember the batch index of the last token for each parallel sequence + // we need this to determine which logits to sample from + std::vector i_batch(n_parallel, batch.n_tokens - 1); + + int n_cur = batch.n_tokens; + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // prepare the next batch + batch.n_tokens = 0; + + // sample the next token for each parallel sequence / stream + for (int32_t i = 0; i < n_parallel; ++i) { + if (i_batch[i] < 0) { + // the stream has already finished + continue; + } + + auto n_vocab = llama_n_vocab(ctx); + auto * logits = llama_get_logits_ith(ctx, i_batch[i]); + + std::vector candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + const int top_k = 40; + const float top_p = 0.9f; + const float temp = 0.4f; + + llama_sample_top_k(ctx, &candidates_p, top_k, 1); + llama_sample_top_p(ctx, &candidates_p, top_p, 1); + llama_sample_temp (ctx, &candidates_p, temp); + + const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); + + //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream? -> mark the stream as finished + if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) { + i_batch[i] = -1; + LOG_TEE("\n"); + if (n_parallel > 1) { + LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); + } + + continue; + } + + // if there is only one stream, we print immediately to stdout + if (n_parallel == 1) { + LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); + fflush(stdout); + } + + streams[i] += llama_token_to_piece(ctx, new_token_id); + + // push this new token for next evaluation + batch.token [batch.n_tokens] = new_token_id; + batch.pos [batch.n_tokens] = n_cur; + batch.seq_id[batch.n_tokens] = i; + batch.logits[batch.n_tokens] = true; + + i_batch[i] = batch.n_tokens; + + batch.n_tokens += 1; + + n_decode += 1; + } + + // all streams are finished + if (batch.n_tokens == 0) { + break; + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch, params.n_threads)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG_TEE("\n"); + + if (n_parallel > 1) { + LOG_TEE("\n"); + + for (int32_t i = 0; i < n_parallel; ++i) { + LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); + } + } + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp index 888ae96656770..63da7c3ec02a5 100644 --- a/examples/beam-search/beam-search.cpp +++ b/examples/beam-search/beam-search.cpp @@ -158,8 +158,9 @@ int main(int argc, char ** argv) } std::cout << std::flush; - int n_past = llama_get_kv_cache_token_count(ctx); - if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) + int n_past = 0; + + if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); return 1; diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index c995eef3514a0..9bd4d34705cbc 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -80,7 +80,8 @@ bool eval_float(void * model, float * input, int N){ if (n_eval > n_batch) { n_eval = n_batch; } - if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { + llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; + if (llama_decode(ctx, batch, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -101,7 +102,7 @@ bool eval_tokens(void * model, std::vector tokens) { if (n_eval > params.n_batch) { n_eval = params.n_batch; } - if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -183,11 +184,11 @@ llama_token sampling_id(struct MyModel* mymodel) { if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling @@ -195,7 +196,7 @@ llama_token sampling_id(struct MyModel* mymodel) { llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); llama_sample_typical(ctx, &candidates_p, typical_p, 1); llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); id = llama_sample_token(ctx, &candidates_p); } } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 27d605f4e13d6..18cefa237bbc1 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { while (!embd_inp.empty()) { int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2f1a1d9ff5645..058e34d5c275c 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat int n_processed = 0; while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads); n_processed += n_tokens; } } @@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_token token = llama_token_bos(ctx); for (int i = 0; i < n_gen; i++) { - llama_eval(ctx, &token, 1, n_past + i, n_threads); + llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads); } } @@ -977,6 +977,8 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); + llama_kv_cache_tokens_rm(ctx, -1, -1); + // warmup run if (t.n_prompt > 0) { test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads); @@ -986,6 +988,8 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { + llama_kv_cache_tokens_rm(ctx, -1, -1); + uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d78112260de08..1ed543cbc627a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -124,7 +124,7 @@ int main(int argc, char ** argv) { console::init(params.simple_io, params.use_color); atexit([]() { console::cleanup(); }); - if (params.perplexity) { + if (params.logits_all) { printf("\n************\n"); printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); printf("************\n\n"); @@ -200,15 +200,6 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - // export the cgraph and exit - if (params.export_cgraph) { - llama_eval_export(ctx, "llama.ggml"); - llama_free(ctx); - llama_free_model(model); - - return 0; - } - std::string path_session = params.path_prompt_cache; std::vector session_tokens; @@ -508,17 +499,22 @@ int main(int argc, char ** argv) { break; } - const int n_left = n_past - params.n_keep; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep); + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; - // always keep the first token - BOS - n_past = std::max(1, params.n_keep); - n_past_guidance = std::max(1, params.n_keep + guidance_offset); + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - // insert n_left/2 tokens at the start of embd from last_tokens - embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size()); + n_past -= n_discard; + + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); @@ -580,7 +576,7 @@ int main(int argc, char ** argv) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } @@ -597,7 +593,7 @@ int main(int argc, char ** argv) { LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); - if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/parallel/CMakeLists.txt b/examples/parallel/CMakeLists.txt new file mode 100644 index 0000000000000..0bbf89eaefce6 --- /dev/null +++ b/examples/parallel/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parallel) +add_executable(${TARGET} parallel.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/parallel/README.md b/examples/parallel/README.md new file mode 100644 index 0000000000000..4d0fe5cef12fa --- /dev/null +++ b/examples/parallel/README.md @@ -0,0 +1,3 @@ +# llama.cpp/example/parallel + +Simplified simluation for serving incoming requests in parallel diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp new file mode 100644 index 0000000000000..790189af98876 --- /dev/null +++ b/examples/parallel/parallel.cpp @@ -0,0 +1,380 @@ +// A basic application simulating a server with multiple clients. +// The clients submite requests to the server and they are processed in parallel. + +#include "build-info.h" + +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + + while (start < end && isspace(str[start])) { + start += 1; + } + + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + + return str.substr(start, end - start); +} + +static std::string k_system = +R"(Transcript of a never ending dialog, where the User interacts with an Assistant. +The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User: Recommend a nice restaurant in the area. +Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. +User: Who is Richard Feynman? +Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". +User:)"; + +static std::vector k_prompts = { + "What is the meaning of life?", + "Tell me an interesting fact about llamas.", + "What is the best way to cook a steak?", + "Are you familiar with the Special Theory of Relativity and can you explain it to me?", + "Recommend some interesting books to read.", + "What is the best way to learn a new language?", + "How to get a job at Google?", + "If you could have any superpower, what would it be?", + "I want to learn how to play the piano.", +}; + +struct client { + int32_t id = 0; + + llama_seq_id seq_id = -1; + + llama_token sampled; + + int64_t t_start_prompt; + int64_t t_start_gen; + + int32_t n_prompt = 0; + int32_t n_decoded = 0; + int32_t i_batch = -1; + + std::string input; + std::string prompt; + std::string response; + + std::vector tokens_prev; +}; + +int main(int argc, char ** argv) { + srand(1234); + + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + // number of simultaneous "clients" to simulate + const int32_t n_clients = params.n_parallel; + + // requests to simulate + const int32_t n_seq = params.n_sequences; + + // insert new requests as soon as the previous one is done + const bool cont_batching = params.cont_batching; + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("parallel", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); +#endif // LOG_DISABLE_LOGS + + // init llama.cpp + llama_backend_init(params.numa); + + llama_model * model = NULL; + llama_context * ctx = NULL; + + // load the target model + params.logits_all = true; + std::tie(model, ctx) = llama_init_from_gpt_params(params); + + fprintf(stderr, "\n\n"); + fflush(stderr); + + const int n_ctx = llama_n_ctx(ctx); + const int n_vocab = llama_n_vocab(ctx); + + std::vector clients(n_clients); + for (size_t i = 0; i < clients.size(); ++i) { + auto & client = clients[i]; + client.id = i; + client.tokens_prev.resize(std::max(256, params.n_predict)); + std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); + } + + std::vector candidates; + candidates.reserve(n_vocab); + + std::vector tokens_system; + tokens_system = ::llama_tokenize(ctx, k_system, true); + const int32_t n_tokens_system = tokens_system.size(); + + llama_seq_id g_seq_id = 0; + + // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple + // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time + llama_batch batch = llama_batch_init(params.n_ctx, 0); + + int32_t n_total_prompt = 0; + int32_t n_total_gen = 0; + int32_t n_cache_miss = 0; + + const auto t_main_start = ggml_time_us(); + + LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__); + LOG_TEE("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); + LOG_TEE("\n"); + + { + LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); + + batch.n_tokens = n_tokens_system; + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + batch.token[i] = tokens_system[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; + } + + if (llama_decode(ctx, batch, params.n_threads) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < n_clients; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); + } + + LOG_TEE("\n"); + } + + LOG_TEE("Processing requests ...\n\n"); + + while (true) { + batch.n_tokens = 0; + + // decode any currently ongoing sequences + for (auto & client : clients) { + if (client.seq_id == -1) { + continue; + } + + batch.token [batch.n_tokens] = client.sampled; + batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded; + batch.seq_id[batch.n_tokens] = client.id; + batch.logits[batch.n_tokens] = true; + + client.n_decoded += 1; + client.i_batch = batch.n_tokens; + + batch.n_tokens += 1; + } + + if (batch.n_tokens == 0) { + // all sequences have ended - clear the entire KV cache + for (int i = 0; i < n_clients; ++i) { + llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); + } + + LOG_TEE("%s: clearing the KV cache\n", __func__); + } + + // insert new sequences for decoding + if (cont_batching || batch.n_tokens == 0) { + for (auto & client : clients) { + if (client.seq_id == -1 && g_seq_id < n_seq) { + client.seq_id = g_seq_id; + + client.t_start_prompt = ggml_time_us(); + client.t_start_gen = 0; + + client.input = k_prompts[rand() % k_prompts.size()]; + client.prompt = client.input + "\nAssistant:"; + client.response = ""; + + std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); + + // do not prepend BOS because we have a system prompt! + std::vector tokens_prompt; + tokens_prompt = ::llama_tokenize(ctx, client.prompt, false); + + for (size_t i = 0; i < tokens_prompt.size(); ++i) { + batch.token [batch.n_tokens] = tokens_prompt[i]; + batch.pos [batch.n_tokens] = i + n_tokens_system; + batch.seq_id[batch.n_tokens] = client.id; + batch.logits[batch.n_tokens] = false; + batch.n_tokens += 1; + } + + // extract the logits only for the last token + if (batch.n_tokens > 0) { + batch.logits[batch.n_tokens - 1] = true; + } + + client.n_prompt = tokens_prompt.size(); + client.n_decoded = 0; + client.i_batch = batch.n_tokens - 1; + + LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); + + g_seq_id += 1; + + // insert new requests one-by-one + //if (cont_batching) { + // break; + //} + } + } + } + + if (batch.n_tokens == 0) { + break; + } + + // process in chunks of params.n_batch + int32_t n_batch = params.n_batch; + + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + // experiment: process in powers of 2 + //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { + // n_batch /= 2; + // i -= n_batch; + // continue; + //} + + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view, params.n_threads); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + return 1; + } + + LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + + n_cache_miss += 1; + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + + continue; + } + + LOG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + + for (auto & client : clients) { + if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { + continue; + } + + //printf("client %d, seq %d, token %d, pos %d, batch %d\n", + // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); + + const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i); + + if (client.n_decoded == 1) { + // start measuring generation time after the first token to make sure all concurrent clients + // have their prompt already processed + client.t_start_gen = ggml_time_us(); + } + + // remember which tokens were sampled - used for repetition penalties during sampling + client.tokens_prev.erase(client.tokens_prev.begin()); + client.tokens_prev.push_back(id); + + const std::string token_str = llama_token_to_piece(ctx, id); + client.response += token_str; + client.sampled = id; + + //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n", + // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); + + if (client.n_decoded > 2 && + (id == llama_token_eos(ctx) || + (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || + client.response.find("User:") != std::string::npos || + client.response.find('\n') != std::string::npos)) { + // basic reverse prompt + const size_t pos = client.response.find("User:"); + if (pos != std::string::npos) { + client.response = client.response.substr(0, pos); + } + + // delete only the generated part of the sequence, i.e. keep the system prompt in the cache + llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("\033[1mClient %3d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\nResponse: %s\n\n", + client.id, client.seq_id, client.n_prompt, client.n_decoded, + (t_main_end - client.t_start_prompt) / 1e6, + (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6, + n_cache_miss, + ::trim(client.input).c_str(), + ::trim(client.response).c_str()); + + n_total_prompt += client.n_prompt; + n_total_gen += client.n_decoded; + + client.seq_id = -1; + } + + client.i_batch = -1; + } + } + } + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("\n\n"); + LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); + LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); + LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6); + LOG_TEE("Cache misses: %6d\n", n_cache_miss); + + LOG_TEE("\n\n"); + + llama_print_timings(ctx); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + fprintf(stderr, "\n\n"); + + return 0; +} diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 2b375e34e7234..de08bd4a185b4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -80,7 +80,9 @@ static void write_logfile( static std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); float max_logit = logits[0]; - for (float v : logits) max_logit = std::max(max_logit, v); + for (float v : logits) { + max_logit = std::max(max_logit, v); + } double sum_exp = 0.0; for (size_t i = 0; i < logits.size(); i++) { // Subtract the maximum logit value from the current logit value for numerical stability @@ -89,15 +91,21 @@ static std::vector softmax(const std::vector& logits) { sum_exp += exp_logit; probs[i] = exp_logit; } - for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; + for (size_t i = 0; i < probs.size(); i++) { + probs[i] /= sum_exp; + } return probs; } static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { float max_logit = logits[0]; - for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]); + for (int i = 1; i < n_vocab; ++i) { + max_logit = std::max(max_logit, logits[i]); + } double sum_exp = 0.0; - for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit); + for (int i = 0; i < n_vocab; ++i) { + sum_exp += expf(logits[i] - max_logit); + } return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; } @@ -108,7 +116,8 @@ static void process_logits( std::mutex mutex; int counter = 0; auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { - double local_nll = 0, local_nll2 = 0; + double local_nll = 0; + double local_nll2 = 0; while (true) { std::unique_lock lock(mutex); int i = counter++; @@ -126,10 +135,13 @@ static void process_logits( prob_history[i] = results.prob; } }; - for (auto & w : workers) w = std::thread(compute); + for (auto & w : workers) { + w = std::thread(compute); + } compute(); - for (auto & w : workers) w.join(); - + for (auto & w : workers) { + w.join(); + } } static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { @@ -152,8 +164,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {std::move(tokens), 0., {}, {}}; } - std::vector logit_history; - std::vector prob_history; + std::vector logit_history; + std::vector prob_history; logit_history.resize(tokens.size()); prob_history.resize(tokens.size()); @@ -195,12 +207,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -320,6 +335,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -332,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par tokens[batch_start] = llama_token_bos(ctx); } - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -402,7 +420,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static std::vector hellaswag_evaluate_tokens( - llama_context * ctx, const std::vector& tokens, int n_past, int n_batch, int n_vocab, int n_thread + llama_context * ctx, std::vector & tokens, int n_past, int n_batch, int n_vocab, int n_thread ) { std::vector result; result.reserve(tokens.size() * n_vocab); @@ -410,7 +428,7 @@ static std::vector hellaswag_evaluate_tokens( for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { size_t n_tokens = tokens.size() - i_chunk * n_batch; n_tokens = std::min(n_tokens, size_t(n_batch)); - if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {}; } @@ -550,6 +568,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { query_embd.resize(32); } + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); @@ -661,7 +682,7 @@ int main(int argc, char ** argv) { return 1; } - params.perplexity = true; + params.logits_all = true; params.n_batch = std::min(params.n_batch, params.n_ctx); if (params.ppl_stride > 0) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 95527bb863524..6e4d40b9e1d6d 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -35,11 +35,11 @@ int main(int argc, char ** argv) { auto last_n_tokens_data = std::vector(params.repeat_last_n, 0); // init - auto model = llama_load_model_from_file(params.model.c_str(), lparams); + auto * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == nullptr) { return 1; } - auto ctx = llama_new_context_with_model(model, lparams); + auto * ctx = llama_new_context_with_model(model, lparams); if (ctx == nullptr) { llama_free_model(model); return 1; @@ -54,7 +54,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { printf("\n%s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx); + auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); std::vector candidates; candidates.reserve(n_vocab); @@ -91,7 +91,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -106,7 +106,7 @@ int main(int argc, char ** argv) { llama_free(ctx); // make new context - auto ctx2 = llama_new_context_with_model(model, lparams); + auto * ctx2 = llama_new_context_with_model(model, lparams); // Load state (rng, logits, embedding and kv_cache) from file { @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx2); + auto * logits = llama_get_logits(ctx2); auto n_vocab = llama_n_vocab(ctx2); std::vector candidates; candidates.reserve(n_vocab); @@ -151,7 +151,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ebd7f2fc579e9..273eb36f4284d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -381,6 +381,10 @@ struct llama_server_context // compare the evaluated prompt with the new prompt n_past = common_part(embd, prompt_tokens); + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); + embd = prompt_tokens; if (n_past == num_prompt_tokens) { @@ -411,19 +415,27 @@ struct llama_server_context if (embd.size() >= (size_t)params.n_ctx) { - // Reset context - const int n_left = (params.n_ctx - params.n_keep) / 2; + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) + { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; - std::vector new_tokens(embd.begin(), embd.begin() + params.n_keep); - new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end()); - embd = new_tokens; - n_past = params.n_keep; truncated = true; LOG_VERBOSE("input truncated", { {"n_ctx", params.n_ctx}, {"n_keep", params.n_keep}, {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, }); } @@ -434,7 +446,8 @@ struct llama_server_context { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) + + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, @@ -523,13 +536,13 @@ struct llama_server_context { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else @@ -540,7 +553,7 @@ struct llama_server_context llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token(ctx, &candidates_p); } } diff --git a/examples/simple/README.md b/examples/simple/README.md new file mode 100644 index 0000000000000..5d24b1046935c --- /dev/null +++ b/examples/simple/README.md @@ -0,0 +1,21 @@ +# llama.cpp/example/simple + +The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt. + +```bash +./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" + +... + +main: n_len = 32, n_ctx = 2048, n_parallel = 1, n_kv_req = 32 + + Hello my name is Shawn and I'm a 20 year old male from the United States. I'm a 20 year old + +main: decoded 27 tokens in 2.31 s, speed: 11.68 t/s + +llama_print_timings: load time = 579.15 ms +llama_print_timings: sample time = 0.72 ms / 28 runs ( 0.03 ms per token, 38888.89 tokens per second) +llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms per token, 15.25 tokens per second) +llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second) +llama_print_timings: total time = 2891.13 ms +``` diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 440d22ecfb4a8..1616a4a7581a3 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -26,12 +26,18 @@ int main(int argc, char ** argv) { params.prompt = "Hello my name is"; } + // total length of the sequence including the prompt + const int n_len = 32; + // init LLM llama_backend_init(params.numa); llama_context_params ctx_params = llama_context_default_params(); + ctx_params.seed = 1234; + ctx_params.n_ctx = 2048; + llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); if (model == NULL) { @@ -41,20 +47,31 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); - const int max_context_size = llama_n_ctx(ctx); - const int max_tokens_list_size = max_context_size - 4; + const int n_ctx = llama_n_ctx(ctx); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); - if ((int) tokens_list.size() > max_tokens_list_size) { - fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); + LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__); return 1; } - fprintf(stderr, "\n\n"); + // print the prompt token-by-token + + fprintf(stderr, "\n"); for (auto id : tokens_list) { fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); @@ -62,63 +79,104 @@ int main(int argc, char ** argv) { fflush(stderr); - // main loop + // create a llama_batch with size 512 + // we use this object to submit token data for decoding - // The LLM keeps a contextual cache memory of previous token evaluation. - // Usually, once this cache is full, it is required to recompute a compressed context based on previous - // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist - // example, we will just stop the loop once this cache is full or once an end of stream is detected. + llama_batch batch = llama_batch_init(512, 0); - const int n_gen = std::min(32, max_context_size); + // evaluate the initial prompt + batch.n_tokens = tokens_list.size(); - while (llama_get_kv_cache_token_count(ctx) < n_gen) { - // evaluate the transformer + for (int32_t i = 0; i < batch.n_tokens; i++) { + batch.token[i] = tokens_list[i]; + batch.pos[i] = i; + batch.seq_id[i] = 0; + batch.logits[i] = false; + } - if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch, params.n_threads) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + // main loop - tokens_list.clear(); + int n_cur = batch.n_tokens; + int n_decode = 0; + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { // sample the next token + { + auto n_vocab = llama_n_vocab(ctx); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_token new_token_id = 0; + std::vector candidates; + candidates.reserve(n_vocab); - auto logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } - std::vector candidates; - candidates.reserve(n_vocab); + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } + // sample the most likely token + const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream? + if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) { + LOG_TEE("\n"); + + break; + } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); + fflush(stdout); - new_token_id = llama_sample_token_greedy(ctx , &candidates_p); + // prepare the next batch + batch.n_tokens = 0; - // is it an end of stream ? - if (new_token_id == llama_token_eos(ctx)) { - fprintf(stderr, " [end of text]\n"); - break; + // push this new token for next evaluation + batch.token [batch.n_tokens] = new_token_id; + batch.pos [batch.n_tokens] = n_cur; + batch.seq_id[batch.n_tokens] = 0; + batch.logits[batch.n_tokens] = true; + + batch.n_tokens += 1; + + n_decode += 1; } - // print the new token : - printf("%s", llama_token_to_piece(ctx, new_token_id).c_str()); - fflush(stdout); + n_cur += 1; - // push this new token for next evaluation - tokens_list.push_back(new_token_id); + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch, params.n_threads)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } } + LOG_TEE("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + llama_free(ctx); llama_free_model(model); llama_backend_free(); - fprintf(stderr, "\n\n"); - return 0; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index aa904183fa2d8..2445d78dc9788 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -37,7 +37,7 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - params.perplexity = true; // HACK: enable logits_all = true + params.logits_all = true; std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); // load the draft model @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads); - llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads); - llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads); const auto t_enc_end = ggml_time_us(); @@ -134,7 +134,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -172,7 +172,8 @@ int main(int argc, char ** argv) { LOG("out of drafted tokens\n"); } - llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); + llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); ++n_past_dft; // heuristic for n_draft @@ -256,7 +257,8 @@ int main(int argc, char ** argv) { } // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); + llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); ++n_past_cur; if (grammar_dft != NULL) { @@ -265,7 +267,8 @@ int main(int argc, char ** argv) { } // evaluate the target model on the drafted tokens - llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); + llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); ++n_past_tgt; // the first token is always proposed by the traget model before the speculation loop diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 59c90c7ba654d..5f541a14100e0 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs( } }; + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // rope has so much parameters that we make a custom function for it - auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale] + auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale] (struct ggml_tensor * t) -> struct ggml_tensor * { // not capturing these, to silcence warnings - const int n_past = 0; const int rope_mode = 0; return ggml_rope_custom(ctx, - t, n_past, n_rot, rope_mode, n_ctx, + t, KQ_pos, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale); }; @@ -787,6 +795,8 @@ struct ggml_tensor * llama_build_train_graphs( ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); // input gradient ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + // KQ_pos + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad)); ggml_allocr_alloc(alloc, t36->grad); // gradient tensors (will be set to zero by ggml_graph_reset) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 79e2d313a3f34..29fb7abd4296a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4369,8 +4369,10 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale) { + +template +static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4379,8 +4381,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const int p = has_pos ? pos[i2] : 0; + const float p0 = p*freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4391,8 +4396,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale) { +template +static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4401,8 +4407,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col/2; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const int p = has_pos ? pos[i2] : 0; + const float p0 = p*freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4413,8 +4422,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -4424,11 +4433,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const int row = blockDim.y*blockIdx.y + threadIdx.y; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; const float col_theta_scale = powf(theta_scale, col); - const float p = p0 + p_delta*(row/p_delta_rows); + // FIXME: this is likely wrong + const int p = pos != nullptr ? pos[i2] : 0; - const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; + const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4438,7 +4449,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale; + const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; const float sin_block_theta = sinf(block_theta); const float cos_block_theta = cosf(block_theta); @@ -5389,31 +5400,41 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { +template +static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); + if (pos == nullptr) { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { +template +static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_neox_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); + if (pos == nullptr) { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { GGML_ASSERT(ncols % 4 == 0); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx); + rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); } static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, @@ -6136,14 +6157,16 @@ inline void ggml_cuda_op_rope( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; + const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -6154,19 +6177,38 @@ inline void ggml_cuda_op_rope( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + + const int32_t * pos = nullptr; + if ((mode & 1) == 0) { + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + pos = (const int32_t *) src1_dd; + } const bool is_neox = mode & 2; const bool is_glm = mode & 4; // compute if (is_glm) { - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream); + GGML_ASSERT(false); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } } (void) src1; @@ -6337,7 +6379,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s } } -void ggml_cuda_set_peer_access(const int n_tokens) { +static void ggml_cuda_set_peer_access(const int n_tokens) { static bool peer_access_enabled = false; const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; @@ -6665,27 +6707,27 @@ static void ggml_cuda_op_mul_mat( } } -void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } -void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } -void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } -void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } -void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } -void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } @@ -6706,7 +6748,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } -void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation @@ -6735,7 +6777,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } -void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); @@ -6769,7 +6811,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; @@ -6813,11 +6855,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } } -void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } -void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -6865,29 +6907,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens (void) dst; } -void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_cpy(src0, dst, nullptr); (void) src1; } -void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); } -void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); } -void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); } -void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } -void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; (void) dst; @@ -7010,11 +7052,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { return extra; } -void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { +static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { if (scratch && g_scratch_size == 0) { return; } + tensor->backend = GGML_BACKEND_GPU; + // recursively assign CUDA buffers until a compute tensor is found if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src[0]->op; @@ -7026,8 +7070,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); } - tensor->backend = GGML_BACKEND_GPU; - if (scratch && no_alloc) { return; } @@ -7112,6 +7154,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) tensor->extra = extra; } +void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); +} + void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { ggml_cuda_assign_buffers_impl(tensor, true, false, false); } diff --git a/ggml-cuda.h b/ggml-cuda.h index a72e82069b9f1..fda704b665623 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor); GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset); +GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor); GGML_API void ggml_cuda_set_main_device(int main_device); GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q); diff --git a/ggml-metal.m b/ggml-metal.m index 654eb67f3392b..b3c463f03ad3d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -103,7 +103,8 @@ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(rope_f32); + GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -293,7 +294,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(rope_f32); + GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -367,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DEL_KERNEL(rope); + GGML_METAL_DEL_KERNEL(rope_f32); + GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); @@ -768,25 +771,59 @@ void ggml_metal_graph_compute( GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; + bool bcast_row = false; - if (ggml_nelements(src1) == ne10) { + int64_t nb = ne00; + + if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; [encoder setComputePipelineState:ctx->pipeline_add_row]; + + bcast_row = true; } else { [encoder setComputePipelineState:ctx->pipeline_add]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - - const int64_t n = ggml_nelements(dst)/4; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } break; case GGML_OP_MUL: { @@ -868,7 +905,7 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - const int nth = 32; + const int nth = MIN(32, ne00); if (ne00%4 == 0) { [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; @@ -921,7 +958,7 @@ void ggml_metal_graph_compute( src1t == GGML_TYPE_F32 && [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && - ne11 > 1) { + ne11 > 2) { switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; @@ -1132,7 +1169,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 512; + const int nth = MIN(512, ne00); [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1151,7 +1188,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 256; + const int nth = MIN(256, ne00); [encoder setComputePipelineState:ctx->pipeline_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1169,6 +1206,8 @@ void ggml_metal_graph_compute( { GGML_ASSERT((src0t == GGML_TYPE_F32)); + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; @@ -1202,12 +1241,14 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - const int nth = 32; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -1217,38 +1258,44 @@ void ggml_metal_graph_compute( memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - [encoder setComputePipelineState:ctx->pipeline_rope]; + switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; + default: GGML_ASSERT(false); + }; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; + [encoder setBytes:&mode length:sizeof( int) atIndex:21]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: { - const int nth = 32; + const int nth = MIN(1024, ne00); switch (src0t) { case GGML_TYPE_F32: diff --git a/ggml-metal.metal b/ggml-metal.metal index 7f1c3d9ea74bd..5e1af6a092aed 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -24,12 +24,59 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +85,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -806,30 +853,61 @@ kernel void kernel_alibi_f32( } } +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template kernel void kernel_rope( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -839,7 +917,9 @@ kernel void kernel_rope( const bool is_neox = mode & 2; - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; const float theta_0 = freq_scale * (float)p; const float inv_ndims = -1.f/n_dims; @@ -851,11 +931,11 @@ kernel void kernel_rope( const float cos_theta = cos(theta); const float sin_theta = sin(theta); - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const T x0 = src[0]; + const T x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -870,8 +950,8 @@ kernel void kernel_rope( const int64_t i0 = ib*n_dims + ic/2; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -883,6 +963,9 @@ kernel void kernel_rope( } } +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[32]; - const uint16_t kmask1 = 0x3030; - const uint16_t kmask2 = 0x0f0f; + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; const int tid = tiisg/4; const int ix = tiisg%4; diff --git a/ggml.c b/ggml.c index a0be068d6c9f7..35751342f9b16 100644 --- a/ggml.c +++ b/ggml.c @@ -6406,6 +6406,54 @@ struct ggml_tensor * ggml_cont_inplace( return ggml_cont_impl(ctx, a, true); } + +// make contiguous, with new shape +GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -6968,7 +7016,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace( static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -6977,7 +7025,10 @@ static struct ggml_tensor * ggml_rope_impl( float xpos_base, bool xpos_down, bool inplace) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + bool is_node = false; if (a->grad) { @@ -6986,7 +7037,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -6996,6 +7047,7 @@ static struct ggml_tensor * ggml_rope_impl( result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7003,55 +7055,55 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); } struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); } struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); } struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); } struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down) { - return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); + return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); } // ggml_rope_back @@ -7059,7 +7111,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -7067,7 +7119,10 @@ struct ggml_tensor * ggml_rope_back( float freq_scale, float xpos_base, bool xpos_down) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); bool is_node = false; @@ -7078,7 +7133,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -7088,6 +7143,7 @@ struct ggml_tensor * ggml_rope_back( result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -8798,8 +8854,6 @@ static void ggml_compute_forward_add_f32( #else ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous @@ -12456,13 +12510,11 @@ static void ggml_compute_forward_alibi_f16( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past const int ne2 = src0->ne[2]; // n_head -> this is k @@ -12477,7 +12529,7 @@ static void ggml_compute_forward_alibi_f16( //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12623,8 +12675,8 @@ static void ggml_compute_forward_clamp( static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12634,9 +12686,9 @@ static void ggml_compute_forward_rope_f32( // these two only relevant for xPos RoPE: float xpos_base; - bool xpos_down; + bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -12645,8 +12697,6 @@ static void ggml_compute_forward_rope_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12677,9 +12727,11 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12716,7 +12768,7 @@ static void ggml_compute_forward_rope_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -12761,8 +12813,8 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12770,15 +12822,13 @@ static void ggml_compute_forward_rope_f16( float freq_base; float freq_scale; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12809,9 +12859,11 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12890,15 +12942,16 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, dst); + ggml_compute_forward_rope_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, dst); + ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; default: { @@ -12912,6 +12965,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -12929,7 +12983,7 @@ static void ggml_compute_forward_rope_back_f32( float xpos_base; bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); @@ -12938,8 +12992,6 @@ static void ggml_compute_forward_rope_back_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12966,9 +13018,11 @@ static void ggml_compute_forward_rope_back_f32( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12980,7 +13034,7 @@ static void ggml_compute_forward_rope_back_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -13023,6 +13077,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -13033,12 +13088,10 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -13065,9 +13118,11 @@ static void ggml_compute_forward_rope_back_f16( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -13119,15 +13174,16 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, dst); + ggml_compute_forward_rope_back_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, dst); + ggml_compute_forward_rope_back_f32(params, src0, src1, dst); } break; default: { @@ -15864,11 +15920,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ALIBI: { @@ -16506,7 +16562,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16523,7 +16579,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad, ggml_rope_back(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16537,7 +16593,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_ROPE_BACK: { if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16554,7 +16610,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad, ggml_rope_impl(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, diff --git a/ggml.h b/ggml.h index b2251acef271e..73198dc61b519 100644 --- a/ggml.h +++ b/ggml.h @@ -1055,7 +1055,6 @@ extern "C" { size_t nb1, size_t offset); - // a -> b, return view(b) GGML_API struct ggml_tensor * ggml_cpy( struct ggml_context * ctx, @@ -1078,6 +1077,33 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // make contiguous, with new shape + GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // return view(a), b specifies the new shape // TODO: when we start computing gradient, make a copy instead of view GGML_API struct ggml_tensor * ggml_reshape( @@ -1225,14 +1251,15 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements + // if mode & 1 == 1, skip n_past elements (DEPRECATED) // if mode & 2 == 1, GPT-NeoX style // if mode & 4 == 1, ChatGLM style - // TODO: avoid creating a new tensor every time + // + // b is an int32 vector with size a->ne[2], it contains the positions GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1241,7 +1268,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1250,7 +1277,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1261,7 +1288,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1272,7 +1299,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down); @@ -1282,7 +1309,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, diff --git a/llama.cpp b/llama.cpp index 6e23a0772325d..140533553c93e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -72,6 +72,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -166,13 +167,13 @@ enum llm_arch { }; static std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, }; @@ -1004,7 +1005,29 @@ struct llama_layer { struct ggml_tensor * b3; // ffn_up }; +struct llama_kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + +// ring-buffer of cached KV data struct llama_kv_cache { + bool has_shift = false; + + uint32_t head = 0; + uint32_t size = 0; + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + struct ggml_tensor * k = NULL; struct ggml_tensor * v = NULL; @@ -1012,8 +1035,6 @@ struct llama_kv_cache { llama_buffer buf; - int n; // number of tokens currently in the cache - ~llama_kv_cache() { if (ctx) { ggml_free(ctx); @@ -1197,16 +1218,23 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, - int n_ctx, int n_gpu_layers) { - const int n_embd = hparams.n_embd_gqa(); - const int n_layer = hparams.n_layer; + const uint32_t n_embd = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_ctx = hparams.n_ctx; const int64_t n_mem = n_layer*n_ctx; const int64_t n_elements = n_embd*n_mem; + cache.has_shift = false; + + cache.head = 0; + cache.size = n_ctx; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - cache.n = 0; struct ggml_init_params params; params.mem_size = cache.buf.size; @@ -1227,10 +1255,10 @@ static bool llama_kv_cache_init( (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer + 1) { + if (n_gpu_layers > (int)n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); } - if (n_gpu_layers > n_layer + 2) { + if (n_gpu_layers > (int)n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); } #endif // GGML_USE_CUBLAS @@ -1238,6 +1266,134 @@ static bool llama_kv_cache_init( return true; } +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +static bool llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > n_ctx) { + cache.head = 0; + n_tested += n_ctx - cache.head; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); + } + + return true; +} + +// find how many cells are currently in use +static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } + + return 0; +} + +static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { + if (c0 < 0) c0 = 0; + if (c1 < 0) c1 = cache.size; + + for (int32_t i = c0; i < c1; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } +} + +static void llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.erase(seq_id); + if (cache.cells[i].seq_id.empty()) { + cache.cells[i].pos = -1; + } + } + } +} + +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (!cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } + } +} + +static void llama_kv_cache_seq_shift( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].pos += delta; + if (cache.cells[i].pos < 0) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } else { + cache.has_shift = true; + cache.cells[i].delta = delta; + } + } + } +} + // // model loading and saving // @@ -2426,15 +2582,7 @@ static bool llama_model_load( static struct ggml_cgraph * llm_build_llama( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -2458,6 +2606,14 @@ static struct ggml_cgraph * llm_build_llama( const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("n_kv = %d\n", n_kv); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2475,12 +2631,12 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -2490,11 +2646,11 @@ static struct ggml_cgraph * llm_build_llama( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -2522,12 +2678,75 @@ static struct ggml_cgraph * llm_build_llama( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2565,33 +2784,33 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); // store key and value to memory { - // compute the transposed [N, n_embd] V matrix + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -2606,7 +2825,7 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -2619,25 +2838,25 @@ static struct ggml_cgraph * llm_build_llama( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -2652,7 +2871,7 @@ static struct ggml_cgraph * llm_build_llama( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -2661,10 +2880,8 @@ static struct ggml_cgraph * llm_build_llama( offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -2755,18 +2972,9 @@ static struct ggml_cgraph * llm_build_llama( return gf; } - static struct ggml_cgraph * llm_build_baichaun( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -2790,6 +2998,12 @@ static struct ggml_cgraph * llm_build_baichaun( const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2807,12 +3021,12 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -2822,11 +3036,11 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -2854,12 +3068,75 @@ static struct ggml_cgraph * llm_build_baichaun( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2901,12 +3178,12 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * Qcur; switch (model.type) { case MODEL_7B: - Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); break; case MODEL_13B: - Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); - Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N); + Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, n_tokens); break; default: GGML_ASSERT(false); @@ -2920,23 +3197,23 @@ static struct ggml_cgraph * llm_build_baichaun( // store key and value to memory { - // compute the transposed [N, n_embd] V matrix + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -2951,7 +3228,7 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -2964,8 +3241,8 @@ static struct ggml_cgraph * llm_build_baichaun( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); @@ -2974,58 +3251,44 @@ static struct ggml_cgraph * llm_build_baichaun( switch (model.type) { case MODEL_7B: - KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); break; case MODEL_13B: - KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); + // TODO: replace with ggml_add() + KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); break; default: GGML_ASSERT(false); } - // KQ_masked = mask_past(KQ_scaled) - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); - // offload_func_kq(KQ_masked); - // ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); -#if 1 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); offload_func_v(KQV); ggml_set_name(KQV, "KQV"); -#else - // make V contiguous in memory to speed up the matmul, however we waste time on the copy - // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation - // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); -#endif // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -3118,15 +3381,7 @@ static struct ggml_cgraph * llm_build_baichaun( static struct ggml_cgraph * llm_build_falcon( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -3150,6 +3405,15 @@ static struct ggml_cgraph * llm_build_falcon( const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", + // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -3167,12 +3431,12 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -3182,11 +3446,11 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -3214,12 +3478,75 @@ static struct ggml_cgraph * llm_build_falcon( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -3276,45 +3603,45 @@ static struct ggml_cgraph * llm_build_falcon( // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for // non-contiguous views is added for the rope operator struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head, N, + ctx0, cur, n_embd_head, n_head, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), 0)); offload_func_kq(tmpq); struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, N, + ctx0, cur, n_embd_head, n_head_kv, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * n_head)); offload_func_kq(tmpk); struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, N, + ctx0, cur, n_embd_head, n_head_kv, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * (n_head + n_head_kv)); offload_func_v(tmpv); // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Kcur); { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); offload_func_v(Vcur); offload_func_v(Vcur->src[0]->src[0]); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); @@ -3327,7 +3654,7 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -3338,21 +3665,21 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -3367,7 +3694,7 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -3425,15 +3752,7 @@ static struct ggml_cgraph * llm_build_falcon( static struct ggml_cgraph * llm_build_starcoder( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -3451,7 +3770,11 @@ static struct ggml_cgraph * llm_build_starcoder( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float norm_eps = hparams.f_norm_eps; + const float norm_eps = hparams.f_norm_eps; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; auto & buf_compute = lctx.buf_compute; @@ -3472,12 +3795,12 @@ static struct ggml_cgraph * llm_build_starcoder( struct ggml_tensor * position; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -3487,21 +3810,21 @@ static struct ggml_cgraph * llm_build_starcoder( GGML_ASSERT(false && "not implemented"); #endif - token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, token); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(token->data, embd, N * n_embd * ggml_element_size(token)); + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); } } { // Compute position embeddings. - struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_positions); if (!ggml_allocr_is_measure(lctx.alloc)) { - for (int i = 0; i < N; ++i) { - ((int32_t *) inp_positions->data)[i] = n_past + i; + for (int i = 0; i < n_tokens; ++i) { + ((int32_t *) inp_positions->data)[i] = batch.pos[i]; } } ggml_set_name(inp_positions, "inp_positions"); @@ -3509,12 +3832,35 @@ static struct ggml_cgraph * llm_build_starcoder( position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); } + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } inpL = ggml_add(ctx0, token, position); ggml_set_name(inpL, "inpL"); @@ -3530,23 +3876,23 @@ static struct ggml_cgraph * llm_build_starcoder( // Self Attention cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); - struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*n_embd); - struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); struct ggml_tensor * Qcur = tmpq; struct ggml_tensor * Kcur = tmpk; { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -3556,13 +3902,13 @@ static struct ggml_cgraph * llm_build_starcoder( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), 0, 2, 1, 3); ggml_set_name(Q, "Q"); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -3573,12 +3919,12 @@ static struct ggml_cgraph * llm_build_starcoder( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) @@ -3588,7 +3934,7 @@ static struct ggml_cgraph * llm_build_starcoder( // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -3601,10 +3947,8 @@ static struct ggml_cgraph * llm_build_starcoder( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); ggml_set_name(cur, "KQV_merged_contiguous"); } @@ -3654,10 +3998,7 @@ static struct ggml_cgraph * llm_build_starcoder( static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { + const llama_batch & batch) { const auto & model = lctx.model; struct ggml_cgraph * result = NULL; @@ -3665,19 +4006,19 @@ static struct ggml_cgraph * llama_build_graph( switch (model.arch) { case LLM_ARCH_LLAMA: { - result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_llama(lctx, batch); } break; case LLM_ARCH_BAICHUAN: { - result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_baichaun(lctx, batch); } break; case LLM_ARCH_FALCON: { - result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_falcon(lctx, batch); } break; case LLM_ARCH_STARCODER: { - result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_starcoder(lctx, batch); } break; default: GGML_ASSERT(false); @@ -3686,55 +4027,91 @@ static struct ggml_cgraph * llama_build_graph( return result; } -// evaluate the transformer +// decode a batch of tokens by evaluating the transformer // // - lctx: llama context -// - tokens: new batch of tokens to process -// - embd embeddings input -// - n_tokens number of tokens -// - n_past: the context size so far +// - batch: batch to evaluate // - n_threads: number of threads to use // -static bool llama_eval_internal( +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_internal( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past, - int n_threads, - const char * cgraph_fname) { + llama_batch batch, + int n_threads) { + const uint32_t n_tokens = batch.n_tokens; - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + if (n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); + return -1; + } - GGML_ASSERT(n_tokens > 0); - GGML_ASSERT(n_past >= 0); - // TODO: keep the values of n_batch and n_ctx - // GGML_ASSERT(n_tokens <= n_batch); - // GGML_ASSERT(n_past + n_tokens <= n_ctx); + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT const int64_t t_start_us = ggml_time_us(); #ifdef GGML_USE_MPI - ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif GGML_ASSERT(n_threads > 0); - const int N = n_tokens; - const auto & model = lctx.model; const auto & hparams = model.hparams; - const auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.kv_self; GGML_ASSERT(!!kv_self.ctx); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; + // helpers for smoother batch API transistion + // after deprecating the llama_eval calls, these will be removed + std::vector pos; + std::vector seq_id; + + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } + + batch.pos = pos.data(); + } + + if (batch.seq_id == nullptr) { + seq_id.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + seq_id[i] = batch.all_seq_id; + } + + batch.seq_id = seq_id.data(); + } + + // we always start to search for a free slot from the start of the cache + // TODO: better strategies can be implemented + kv_self.head = 0; + + if (!llama_kv_cache_find_slot(kv_self, batch)) { + return 1; + } + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? + kv_self.n = std::min((int32_t) hparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); + + //printf("kv_self.n = %d\n", kv_self.n); + ggml_allocr_reset(lctx.alloc); - ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past); + ggml_cgraph * gf = llama_build_graph(lctx, batch); ggml_allocr_alloc_graph(lctx.alloc, gf); @@ -3743,6 +4120,7 @@ static bool llama_eval_internal( ggml_tensor * node = gf->leafs[i]; if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_copy_to_device(node); } } @@ -3761,7 +4139,7 @@ static bool llama_eval_internal( // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering // with the BLAS calls. need a better solution - if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { n_threads = std::min(4, n_threads); } @@ -3800,12 +4178,9 @@ static bool llama_eval_internal( ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif - // update kv token count - lctx.kv_self.n = n_past + N; - - if (cgraph_fname) { - ggml_graph_export(gf, cgraph_fname); - } + // update the kv ring buffer + lctx.kv_self.head += n_tokens; + lctx.kv_self.has_shift = false; #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -3822,13 +4197,20 @@ static bool llama_eval_internal( { auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); + if (batch.logits) { + logits_out.resize(n_vocab * n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); + } + } else if (lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); } else { - // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } } @@ -3837,20 +4219,27 @@ static bool llama_eval_internal( auto & embedding_out = lctx.embedding; embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); } // measure the performance only for the single-token evals - if (N == 1) { + if (n_tokens == 1) { lctx.t_eval_us += ggml_time_us() - t_start_us; lctx.n_eval++; } - else if (N > 1) { + else if (n_tokens > 1) { lctx.t_p_eval_us += ggml_time_us() - t_start_us; - lctx.n_p_eval += N; + lctx.n_p_eval += n_tokens; } - return true; + // get a more accurate load time, upon first eval + // TODO: fix this + if (!lctx.has_evaluated_once) { + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; + lctx.has_evaluated_once = true; + } + + return 0; } // @@ -4675,6 +5064,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) // sampling // +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); @@ -4883,7 +5279,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c } } -void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { +void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { const int64_t t_start_sample_us = ggml_time_us(); for (size_t i = 0; i < candidates_p->size; ++i) { @@ -4895,6 +5291,10 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array } } +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + llama_sample_temp(ctx, candidates_p, temp); +} + void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; @@ -5324,7 +5724,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); @@ -5398,7 +5798,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -6321,7 +6721,7 @@ struct llama_context * llama_new_context_with_model( // reserve memory for context buffers if (!params.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -6354,10 +6754,10 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new_measure(tensor_alignment); // build worst-case graph - int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); - int n_past = hparams.n_ctx - n_tokens; + const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch); llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0)); + #ifdef GGML_USE_METAL if (params.n_gpu_layers > 0) { ctx->ctx_metal = ggml_metal_init(1); @@ -6367,8 +6767,8 @@ struct llama_context * llama_new_context_with_model( return NULL; } ggml_metal_log_set_callback(llama_log_callback_default, NULL); - ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); } #endif // measure memory requirements for the graph @@ -6383,7 +6783,7 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); #ifdef GGML_USE_METAL if (ctx->ctx_metal) { - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); } #endif #ifdef GGML_USE_CUBLAS @@ -6439,8 +6839,10 @@ struct llama_context * llama_new_context_with_model( if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; llama_backend_free(); exit(1); } @@ -6558,16 +6960,27 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha } int llama_get_kv_cache_token_count(const struct llama_context * ctx) { - return ctx->kv_self.n; + return ctx->kv_self.head; } -#define LLAMA_MAX_RNG_STATE (64*1024) +void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { + llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); +} -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - ctx->rng.seed(seed); +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); +} + +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_kv_cache_seq_keep(ctx->kv_self, seq_id); +} + +void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } // Returns the *maximum* size of the state @@ -6655,6 +7068,16 @@ struct llama_data_file_context : llama_data_context { * */ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + // TODO: does not support multi-sequence states + { + const auto & kv_self = ctx->kv_self; + for (uint32_t i = 0; i < kv_self.head; ++i) { + GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); + GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); + GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); + } + } + // copy rng { std::stringstream rng_ss; @@ -6710,7 +7133,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const int n_ctx = hparams.n_ctx; const size_t kv_size = kv_self.buf.size; - const int kv_ntok = llama_get_kv_cache_token_count(ctx); + const int kv_ntok = kv_self.head; data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_ntok, sizeof(kv_ntok)); @@ -6854,7 +7277,8 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_free(cpy_ctx); } - ctx->kv_self.n = kv_ntok; + ctx->kv_self.head = kv_ntok; + ctx->kv_self.size = kv_size; } const size_t nread = inp - src; @@ -6949,64 +7373,100 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi int llama_eval( struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, + llama_token * tokens, + int32_t n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; - } + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } - return 0; + return ret; } int llama_eval_embd( struct llama_context * ctx, - const float * embd, - int n_tokens, + float * embd, + int32_t n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; - } + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; + + const int ret = llama_decode_internal(*ctx, batch, n_threads); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } - return 0; + return ret; } -int llama_eval_export(struct llama_context * ctx, const char * fname) { - const int n_batch = 1; - const int n_ctx = 512 - n_batch; +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, + }; +} - const std::vector tmp(n_batch, llama_token_bos(ctx)); +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { + llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; - if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - return 0; + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.seq_id) free(batch.seq_id); + if (batch.logits) free(batch.logits); +} + +int llama_decode( + struct llama_context * ctx, + struct llama_batch batch, + int n_threads) { + const int ret = llama_decode_internal(*ctx, batch, n_threads); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; } float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + return ctx->logits.data() + i*ctx->model.hparams.n_vocab; +} + float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } diff --git a/llama.h b/llama.h index 350268b9a94aa..e07c09f146569 100644 --- a/llama.h +++ b/llama.h @@ -37,6 +37,8 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_MAX_RNG_STATE (64*1024) + #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -60,7 +62,9 @@ extern "C" { struct llama_model; struct llama_context; - typedef int llama_token; + typedef int32_t llama_pos; + typedef int32_t llama_token; + typedef int32_t llama_seq_id; enum llama_vocab_type { LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece @@ -80,24 +84,24 @@ extern "C" { // model file types enum llama_ftype { LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -116,6 +120,35 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); + // Input data for llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // - seq_id : the sequence to which the respective token belongs + // - logits : if zero, the logits for the respective token will not be output + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + llama_seq_id * seq_id; + int8_t * logits; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + } llama_batch; + struct llama_context_params { uint32_t seed; // RNG seed, -1 for random int32_t n_ctx; // text context @@ -202,6 +235,7 @@ extern "C" { int32_t n_eval; }; + // Helpers for getting default parameters LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -246,8 +280,10 @@ extern "C" { // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -268,7 +304,7 @@ extern "C" { const char * path_lora, const char * path_base_model, int n_threads), - "please use llama_model_apply_lora_from_file instead"); + "use llama_model_apply_lora_from_file instead"); LLAMA_API int llama_model_apply_lora_from_file( const struct llama_model * model, @@ -276,11 +312,53 @@ extern "C" { const char * path_base_model, int n_threads); + // + // KV cache + // + // Returns the number of tokens in the KV cache - LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); + LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), + "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + // Remove all tokens data of cells in [c0, c1) + LLAMA_API void llama_kv_cache_tokens_rm( + struct llama_context * ctx, + int32_t c0, + int32_t c1); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + LLAMA_API void llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + + // Copy all tokens that belong to the specified sequence to another sequence + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + + // Removes all tokens that do not belong to the specified sequence + LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly + LLAMA_API void llama_kv_cache_seq_shift( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + + // + // State / sessions + // // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens @@ -289,48 +367,100 @@ extern "C" { // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); + LLAMA_API size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); + LLAMA_API size_t llama_set_state_data( + struct llama_context * ctx, + uint8_t * src); // Save/load session file - LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); - LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + LLAMA_API bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + LLAMA_API bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); - // Run the llama inference to obtain the logits and probabilities for the next token. + // + // Decoding + // + + // Run the llama inference to obtain the logits and probabilities for the next token(s). // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls // Returns 0 on success - LLAMA_API int llama_eval( + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval( struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, + llama_token * tokens, + int32_t n_tokens, int n_past, - int n_threads); + int n_threads), + "use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. - LLAMA_API int llama_eval_embd( + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval_embd( struct llama_context * ctx, - const float * embd, - int n_tokens, + float * embd, + int32_t n_tokens, int n_past, - int n_threads); + int n_threads), + "use llama_decode() instead"); - // Export a static computation graph for context of 511 and batch size of 1 - // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these - // parameters here to keep things simple - // IMPORTANT: do not use for anything else other than debugging and testing! - LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); + // Return batch for single sequence of tokens starting at pos_0 + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id); + + // Allocates a batch of tokens on the heap + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + + // Positive return values does not mean a fatal error, but rather a warning. + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // < 0 - error + LLAMA_API int llama_decode( + struct llama_context * ctx, + struct llama_batch batch, + int n_threads); // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row - // Can be mutated in order to change the probabilities of the next token - // Rows: n_tokens + // Logits for which llama_batch.logits[i] == 0 are undefined + // Rows: n_tokens provided with llama_batch // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + // Logits for the ith token. Equivalent to: + // llama_get_logits(ctx) + i*n_vocab + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get the embeddings for the input // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -407,11 +537,25 @@ extern "C" { // Sampling functions // + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_repetition_penalty( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float penalty); /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + LLAMA_API void llama_sample_frequency_and_presence_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float alpha_frequency, + float alpha_presence); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. @@ -424,23 +568,54 @@ extern "C" { float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API void llama_sample_softmax( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); + LLAMA_API void llama_sample_top_k( + struct llama_context * ctx, + llama_token_data_array * candidates, + int k, + size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); + LLAMA_API void llama_sample_top_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); + LLAMA_API void llama_sample_tail_free( + struct llama_context * ctx, + llama_token_data_array * candidates, + float z, + size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); - LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + LLAMA_API void llama_sample_typical( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + LLAMA_API void llama_sample_temp( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp); + + LLAMA_API DEPRECATED(void llama_sample_temperature( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp), + "use llama_sample_temp instead"); /// @details Apply constraints from grammar - LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -448,23 +623,41 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + int m, + float * mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat_v2( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); /// @details Selects the token with the highest probability. - LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token_greedy( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Randomly selects a token from the candidates based on their probabilities. - LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + LLAMA_API void llama_grammar_accept_token( + struct llama_context * ctx, + struct llama_grammar * grammar, + llama_token token); // // Beam search @@ -472,9 +665,10 @@ extern "C" { struct llama_beam_view { const llama_token * tokens; + size_t n_tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Callback should set this to true when a beam is at end-of-beam. + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. }; // Passed to beam_search_callback function. @@ -483,9 +677,10 @@ extern "C" { // These pointers are valid only during the synchronous callback, so should not be saved. struct llama_beams_state { struct llama_beam_view * beam_views; + size_t n_beams; // Number of elements in beam_views[]. size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. - bool last_call; // True iff this is the last callback invocation. + bool last_call; // True iff this is the last callback invocation. }; // Type of pointer to the beam_search_callback function. @@ -501,10 +696,18 @@ extern "C" { /// @param n_past Number of tokens already evaluated. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. /// @param n_threads Number of threads as passed to llama_eval(). - LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); + LLAMA_API void llama_beam_search( + struct llama_context * ctx, + llama_beam_search_callback_fn_t callback, + void * callback_data, + size_t n_beams, + int n_past, + int n_predict, + int n_threads); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_reset_timings(struct llama_context * ctx); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 916dc9d055a2d..a19e1376ed389 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -37,6 +37,8 @@ llama_build_and_test_executable(test-llama-grammar.cpp) llama_build_and_test_executable(test-grad0.cpp) # SLOW # llama_build_and_test_executable(test-opt.cpp) # SLOW +llama_build_and_test_executable(test-rope.cpp) + # dummy executable - not installed get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 468cde66adc65..7b0c0fcdbb54c 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1404,6 +1404,11 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + for (int i = 0; i < ne2[2]; ++i) { + ((int32_t *) p->data)[i] = n_past + i; + } + ggml_set_param(ctx0, x[0]); const bool skip_past = (mode & 1); @@ -1415,7 +1420,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); @@ -1438,6 +1443,11 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + for (int i = 0; i < ne2[2]; ++i) { + ((int32_t *) p->data)[i] = n_past + i; + } + ggml_set_param(ctx0, x[0]); const bool skip_past = (mode & 1); @@ -1449,7 +1459,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp new file mode 100644 index 0000000000000..26c1f42dc0e95 --- /dev/null +++ b/tests/test-rope.cpp @@ -0,0 +1,221 @@ +#include "ggml.h" + +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wdouble-promotion" +#endif + +#define MAX_NARGS 3 + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define GGML_SILU_FP16 + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) printf(__VA_ARGS__) + +static float frand(void) { + return (float)rand()/(float)RAND_MAX; +} + +static int irand(int n) { + if (n == 0) return 0; + return rand()%n; +} + +static void get_random_dims(int64_t * dims, int ndims) { + dims[0] = dims[1] = dims[2] = dims[3] = 1; + + for (int i = 0; i < ndims; i++) { + dims[i] = 1 + irand(4); + } +} + +static struct ggml_tensor * get_random_tensor_f32( + struct ggml_context * ctx0, + int ndims, + const int64_t ne[], + float fmin, + float fmax) { + struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); + + switch (ndims) { + case 1: + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; + } + break; + case 2: + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + break; + case 3: + for (int i2 = 0; i2 < ne[2]; i2++) { + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + } + break; + case 4: + for (int i3 = 0; i3 < ne[3]; i3++) { + for (int i2 = 0; i2 < ne[2]; i2++) { + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + } + } + break; + default: + assert(false); + }; + + return result; +} + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +int main(int /*argc*/, const char ** /*argv*/) { + struct ggml_init_params params = { + /* .mem_size = */ 128*1024*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ false, + }; + + std::vector work_buffer; + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_tensor * x; + + // rope f32 + for (int m = 0; m < 3; ++m) { + const int ndims = 4; + + const int64_t n_rot = 128; + const int64_t ne[4] = { 2*n_rot, 32, 73, 1 }; + + const int n_past_0 = 100; + const int n_past_2 = 33; + + struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + + for (int i = 0; i < ne[2]; ++i) { + ((int32_t *) p0->data)[i] = n_past_0 + i; + ((int32_t *) p1->data)[i] = n_past_2 - n_past_0; + ((int32_t *) p2->data)[i] = n_past_2 + i; + } + + // test mode 0, 2, 4 (standard, GPT-NeoX, GLM) + const int mode = m == 0 ? 0 : m == 1 ? 2 : 4; + + x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); + + // 100, 101, 102, ..., 172 + struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024); + // -67, -67, -67, ..., -67 + struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens + + // 33, 34, 35, ..., 105 + struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_build_forward_expand(gf, r0); + ggml_build_forward_expand(gf, r1); + ggml_build_forward_expand(gf, r2); + + ggml_graph_compute_helper(work_buffer, gf, 4); + + // check that r1 and r2 are the same + { + double sum0 = 0.0f; + double sum1 = 0.0f; + double diff = 0.0f; + + const float * r1_data = (float *) r1->data; + const float * r2_data = (float *) r2->data; + + const int n_elements = ggml_nelements(r1); + + for (int i = 0; i < n_elements; ++i) { + sum0 += fabs(r1_data[i]); + sum1 += fabs(r2_data[i]); + diff += fabs(r1_data[i] - r2_data[i]); + //if (fabs(r1_data[i] - r2_data[i]) > 0.0001f) { + // printf("%d: %f %f\n", i, r1_data[i], r2_data[i]); + // printf("diff: %f\n", fabs(r1_data[i] - r2_data[i])); + //} + } + + //for (int i = 4096; i < 4096 + 128; ++i) { + // printf("%f %f\n", r1_data[i], r2_data[i]); + //} + + printf("mode: %d\n", mode); + printf("sum0: %f\n", sum0); + printf("sum1: %f\n", sum1); + printf("diff: %f\n", diff); + printf("rel err: %f\n", diff / sum0); + printf("rel err: %f\n", diff / sum1); + + GGML_ASSERT(diff / sum0 < 0.0001f); + GGML_ASSERT(diff / sum1 < 0.0001f); + } + } + + ggml_free(ctx0); + + return 0; +} +