From 6dcc02d2444c779c18d49c364c5d5c5728b6b484 Mon Sep 17 00:00:00 2001 From: Alexey Parfenov Date: Fri, 16 Feb 2024 11:33:25 +0000 Subject: [PATCH] server : add "samplers" param to control the samplers order (#5494) --- common/common.cpp | 59 ++++++++++++++++++++++++-------------- common/common.h | 2 +- common/sampling.cpp | 2 +- common/sampling.h | 14 ++++----- examples/server/README.md | 2 ++ examples/server/server.cpp | 25 ++++++++++++++++ 6 files changed, 74 insertions(+), 30 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c5e83cc2a9e40..3a92d3797492f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -341,7 +341,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = sampler_types_from_names(sampler_names); + sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); } else if (arg == "--sampling-seq") { if (++i >= argc) { invalid_param = true; @@ -964,7 +964,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str()); + printf(" --samplers samplers that will be used for generation in the order, separated by \';\'\n"); + printf(" (default: %s)\n", sampler_type_names.c_str()); printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str()); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); @@ -1133,34 +1134,50 @@ std::vector string_split(std::string input, char separator) { return parts; } -std::vector sampler_types_from_names(const std::vector & names) { +std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + {"top_k", llama_sampler_type::TOP_K}, + {"top_p", llama_sampler_type::TOP_P}, + {"typical_p", llama_sampler_type::TYPICAL_P}, + {"min_p", llama_sampler_type::MIN_P}, + {"tfs_z", llama_sampler_type::TFS_Z}, + {"temperature", llama_sampler_type::TEMPERATURE} + }; + // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map sampler_name_map { - {"top_k", llama_sampler_type::TOP_K}, + std::unordered_map sampler_alt_name_map { {"top-k", llama_sampler_type::TOP_K}, - {"top_p", llama_sampler_type::TOP_P}, {"top-p", llama_sampler_type::TOP_P}, {"nucleus", llama_sampler_type::TOP_P}, - {"typical_p", llama_sampler_type::TYPICAL_P}, {"typical-p", llama_sampler_type::TYPICAL_P}, {"typical", llama_sampler_type::TYPICAL_P}, - {"min_p", llama_sampler_type::MIN_P}, {"min-p", llama_sampler_type::MIN_P}, - {"tfs_z", llama_sampler_type::TFS_Z}, {"tfs-z", llama_sampler_type::TFS_Z}, {"tfs", llama_sampler_type::TFS_Z}, - {"temp", llama_sampler_type::TEMP}, - {"temperature", llama_sampler_type::TEMP} + {"temp", llama_sampler_type::TEMPERATURE} }; std::vector sampler_types; sampler_types.reserve(names.size()); - for (const auto& name : names) { - const auto sampler_item = sampler_name_map.find(name); - if (sampler_item != sampler_name_map.end()) { + for (const auto & name : names) + { + auto sampler_item = sampler_canonical_name_map.find(name); + if (sampler_item != sampler_canonical_name_map.end()) + { sampler_types.push_back(sampler_item->second); } + else + { + if (allow_alt_names) + { + sampler_item = sampler_alt_name_map.find(name); + if (sampler_item != sampler_alt_name_map.end()) + { + sampler_types.push_back(sampler_item->second); + } + } + } } return sampler_types; } @@ -1172,7 +1189,7 @@ std::vector sampler_types_from_chars(const std::string & nam {'y', llama_sampler_type::TYPICAL_P}, {'m', llama_sampler_type::MIN_P}, {'f', llama_sampler_type::TFS_Z}, - {'t', llama_sampler_type::TEMP} + {'t', llama_sampler_type::TEMPERATURE} }; std::vector sampler_types; @@ -1188,12 +1205,12 @@ std::vector sampler_types_from_chars(const std::string & nam std::string sampler_type_to_name_string(llama_sampler_type sampler_type) { switch (sampler_type) { - case llama_sampler_type::TOP_K: return "top_k"; - case llama_sampler_type::TFS_Z: return "tfs_z"; - case llama_sampler_type::TYPICAL_P: return "typical_p"; - case llama_sampler_type::TOP_P: return "top_p"; - case llama_sampler_type::MIN_P: return "min_p"; - case llama_sampler_type::TEMP: return "temp"; + case llama_sampler_type::TOP_K: return "top_k"; + case llama_sampler_type::TFS_Z: return "tfs_z"; + case llama_sampler_type::TYPICAL_P: return "typical_p"; + case llama_sampler_type::TOP_P: return "top_p"; + case llama_sampler_type::MIN_P: return "min_p"; + case llama_sampler_type::TEMPERATURE: return "temperature"; default : return ""; } } diff --git a/common/common.h b/common/common.h index 74c1369953d48..935771d44ca9c 100644 --- a/common/common.h +++ b/common/common.h @@ -165,7 +165,7 @@ void process_escapes(std::string& input); // String utils // -std::vector sampler_types_from_names(const std::vector & names); +std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector sampler_types_from_chars(const std::string & names_string); std::vector string_split(std::string input, char separator); std::string sampler_type_to_name_string(llama_sampler_type sampler_type); diff --git a/common/sampling.cpp b/common/sampling.cpp index a001750da0ce2..53013138a9eb4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -139,7 +139,7 @@ static void sampler_queue( case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case llama_sampler_type::TEMP: + case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range); diff --git a/common/sampling.h b/common/sampling.h index 2bd6a75d21534..e1279a8941ce0 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -10,12 +10,12 @@ // sampler types enum class llama_sampler_type : char { - TOP_K = 'k', - TOP_P = 'p', - MIN_P = 'm', - TFS_Z = 'f', - TYPICAL_P = 'y', - TEMP = 't' + TOP_K = 'k', + TOP_P = 'p', + MIN_P = 'm', + TFS_Z = 'f', + TYPICAL_P = 'y', + TEMPERATURE = 't' }; // sampling parameters @@ -45,7 +45,7 @@ typedef struct llama_sampling_params { llama_sampler_type::TYPICAL_P, llama_sampler_type::TOP_P, llama_sampler_type::MIN_P, - llama_sampler_type::TEMP + llama_sampler_type::TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling diff --git a/examples/server/README.md b/examples/server/README.md index 8e141d22d1716..249368749ff07 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -204,6 +204,8 @@ node index.js `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) + `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. (default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values) + ### Result JSON - Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0cb802ce851ad..a0b46970b83a9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -672,6 +672,24 @@ struct llama_server_context } } + const auto &samplers_sequence = data.find("samplers"); + if (samplers_sequence != data.end() && samplers_sequence->is_array()) + { + std::vector sampler_names; + for (const auto &sampler_name : *samplers_sequence) + { + if (sampler_name.is_string()) + { + sampler_names.emplace_back(sampler_name); + } + } + slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); + } + else + { + slot->sparams.samplers_sequence = default_sparams.samplers_sequence; + } + if (multimodal) { const auto &images_data = data.find("image_data"); @@ -1026,6 +1044,12 @@ struct llama_server_context const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + std::vector samplers_sequence; + for (const auto &sampler_type : slot.sparams.samplers_sequence) + { + samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type)); + } + return json { {"n_ctx", slot.n_ctx}, {"model", params.model_alias}, @@ -1056,6 +1080,7 @@ struct llama_server_context {"logit_bias", slot.sparams.logit_bias}, {"n_probs", slot.sparams.n_probs}, {"grammar", slot.sparams.grammar}, + {"samplers", samplers_sequence} }; }