llama : improve infill support and special token detection #9798

Merged 4 commits on Oct 12, 2024
248 changes: 111 additions & 137 deletions common/arg.cpp

Large diffs are not rendered by default.

18 changes: 17 additions & 1 deletion common/common.cpp
@@ -12,6 +12,7 @@

#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <cstdarg>
@@ -23,10 +24,10 @@
#include <regex>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <thread>

#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
@@ -400,6 +401,21 @@ std::string common_params_get_system_info(const common_params & params) {
// String utils
//

std::string string_format(const char * fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}

std::vector<std::string> string_split(std::string input, char separator) {
std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
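The new `string_format` helper uses a two-pass `vsnprintf` idiom: the first call measures the required length, the second formats into a correctly sized buffer. The standalone sketch below reproduces that idiom outside the llama.cpp tree so it can be compiled and run directly; the name `format_to_string` and the plain `assert` stand in for `string_format` and `GGML_ASSERT` and are not part of the PR.

```cpp
// Minimal standalone sketch of the two-pass vsnprintf pattern used by string_format().
// format_to_string and assert are illustrative stand-ins, not names from the PR.
#include <cassert>
#include <climits>
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

static std::string format_to_string(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);                                      // vsnprintf consumes a va_list, so keep a copy
    const int size = vsnprintf(nullptr, 0, fmt, ap);       // first pass: compute the length only
    assert(size >= 0 && size < INT_MAX);
    std::vector<char> buf(size + 1);                       // +1 for the terminating '\0'
    const int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); // second pass: write the characters
    assert(size2 == size);
    va_end(ap2);
    va_end(ap);
    return std::string(buf.data(), size);
}

int main() {
    const std::string msg = format_to_string("[restored %d tokens from file]", 128);
    printf("%s\n", msg.c_str());
    return 0;
}
```

The `va_copy` is what makes the two passes possible, since a `va_list` cannot be reused after it has been handed to `vsnprintf`.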
19 changes: 16 additions & 3 deletions common/common.h
@@ -352,15 +352,28 @@ void common_init();

std::string common_params_get_system_info(const common_params & params);

bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
bool set_process_priority(enum ggml_sched_priority prio);

//
// String utils
//

#ifdef __GNUC__
#ifdef __MINGW32__
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif

LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
std::string string_format(const char * fmt, ...);

std::vector<std::string> string_split(std::string input, char separator);

std::string string_strip(const std::string & str);
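The `LLAMA_COMMON_ATTRIBUTE_FORMAT` macro is there so GCC and Clang can type-check the variadic arguments of `string_format` against its format string at compile time (`gnu_printf` is chosen on MinGW, where the default `printf` archetype follows the MSVC runtime). The self-contained sketch below shows the effect; `my_format` and the simplified `ATTR_FORMAT` macro are illustrative, not the PR's declarations.

```cpp
// Sketch: how __attribute__((format(printf, 1, 2))) catches mismatched arguments.
// my_format and ATTR_FORMAT are illustrative stand-ins; the real macro also
// selects gnu_printf on MinGW.
#include <cstdarg>
#include <cstdio>

#ifdef __GNUC__
#define ATTR_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#else
#define ATTR_FORMAT(...)
#endif

ATTR_FORMAT(1, 2)                       // argument 1 is the format string, varargs start at 2
static void my_format(const char * fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    vprintf(fmt, ap);
    va_end(ap);
}

int main() {
    my_format("n_tokens = %d\n", 42);       // OK
    // my_format("n_tokens = %d\n", "42");  // with the attribute, -Wformat warns: %d expects int
    return 0;
}
```

Compiling the commented-out call with `-Wformat` enabled produces a warning, which is exactly the class of mistake the attribute catches at build time.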
14 changes: 7 additions & 7 deletions examples/infill/infill.cpp
@@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

GGML_ASSERT(llama_token_prefix(model) >= 0);
GGML_ASSERT(llama_token_suffix(model) >= 0);
GGML_ASSERT(llama_token_fim_pre(model) >= 0);
GGML_ASSERT(llama_token_fim_suf(model) >= 0);

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
@@ -218,7 +218,7 @@ int main(int argc, char ** argv) {
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

const llama_token middle_token = llama_token_middle(model);
const llama_token middle_token = llama_token_fim_mid(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
@@ -508,8 +508,8 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
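For readers new to fill-in-the-middle (FIM) prompting, the sketch below shows the token layout this code assembles: `FIM_PRE prefix FIM_SUF suffix FIM_MID` by default, or the suffix-first (SPM) ordering when `--spm-infill` is set, with an optional leading BOS. The token IDs are made-up placeholders; real code obtains them from `llama_token_fim_pre/suf/mid(model)` as in the diff above.

```cpp
// Illustrative only: token IDs below are placeholders, not real vocabulary entries.
#include <cstdio>
#include <vector>

int main() {
    const int BOS = 0, FIM_PRE = 1, FIM_SUF = 2, FIM_MID = 3;

    std::vector<int> prefix = {10, 11, 12}; // tokenized text before the cursor
    std::vector<int> suffix = {20, 21};     // tokenized text after the cursor

    const bool spm_infill = false;          // true -> suffix-prefix-middle ordering
    const bool add_bos    = true;           // mirrors llama_add_bos_token(model)

    std::vector<int> inp_pfx = prefix;
    inp_pfx.insert(inp_pfx.begin(), FIM_PRE);

    std::vector<int> inp_sfx = suffix;
    inp_sfx.insert(inp_sfx.begin(), FIM_SUF);

    std::vector<int> embd_inp = spm_infill ? inp_sfx : inp_pfx;
    std::vector<int> embd_end = spm_infill ? inp_pfx : inp_sfx;

    if (add_bos) {
        embd_inp.insert(embd_inp.begin(), BOS);
    }
    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
    embd_inp.push_back(FIM_MID);            // generation fills in the middle from here

    for (int t : embd_inp) {
        printf("%d ", t);                   // prints: 0 1 10 11 12 2 20 21 3
    }
    printf("\n");
    return 0;
}
```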
2 changes: 1 addition & 1 deletion examples/server/README.md
@@ -526,7 +526,7 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
- `input_prefix`: Set the prefix of the code to infill.
- `input_suffix`: Set the suffix of the code to infill.

It also accepts all the options of `/completion` except `stream` and `prompt`.
It also accepts all the options of `/completion`.

### **GET** `/props`: Get server global properties.

162 changes: 87 additions & 75 deletions examples/server/server.cpp
@@ -753,12 +753,7 @@ struct server_context {
metrics.init();
}

std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
// TODO: currently, we tokenize using special tokens by default
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
// but it's better compared to completely ignoring ChatML and other chat templates
const bool TMP_FORCE_SPECIAL = true;

std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@

std::vector<llama_token> p;
if (first) {
p = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
p = common_tokenize(ctx, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
p = common_tokenize(ctx, s, false, parse_special);
}

prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
}

return prompt_tokens;
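The substantive change in this hunk is that the hard-coded `TMP_FORCE_SPECIAL` flag becomes a caller-controlled `parse_special` parameter. With `parse_special = true`, text that matches a special vocabulary token is collapsed to that single token; with `false`, the same text is tokenized as literal characters. A sketch of the difference follows, assuming an already-loaded `llama_context` and the four-argument `common_tokenize` overload used in this diff; the ChatML markers are only an example and depend on the model's vocabulary.

```cpp
// Sketch only: assumes `ctx` is an initialized llama_context for a model whose
// vocabulary defines <|im_start|>/<|im_end|> as special tokens. Uses the
// common_tokenize(ctx, text, add_special, parse_special) overload from common.h.
#include <cstdio>
#include <string>
#include <vector>

#include "common.h"

static void compare_parse_special(llama_context * ctx) {
    const std::string text = "<|im_start|>user\nhello<|im_end|>";

    // markers collapse to single special tokens
    const std::vector<llama_token> special = common_tokenize(ctx, text, /*add_special=*/false, /*parse_special=*/true);

    // markers are treated as literal text and split into many ordinary tokens
    const std::vector<llama_token> literal = common_tokenize(ctx, text, /*add_special=*/false, /*parse_special=*/false);

    printf("parse_special=true  -> %zu tokens\n", special.size());
    printf("parse_special=false -> %zu tokens\n", literal.size());
}
```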
@@ -1215,7 +1210,7 @@ struct server_context {
slot.params.n_predict, n_ctx_train);
}

SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

return slot.has_next_token; // continue
}
@@ -1483,9 +1478,8 @@ struct server_context {
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
data["index"] = 0;
create_task(data, false, nullptr);
}
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
else if (prompt.is_array()) {
} else if (prompt.is_array()) {
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
std::vector<json> prompts = prompt;
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// prompts[0] is the question
@@ -1510,9 +1504,8 @@
}
}
}
}
// invalid case
else {
} else {
// invalid case
throw std::runtime_error(error_msg);
}

@@ -1785,6 +1778,9 @@ struct server_context {
}
slot->cache_tokens.resize(token_count);

// TODO: maybe detokenize the slot->cache_tokens instead?
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);

const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;

@@ -1971,70 +1967,69 @@ struct server_context {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;

if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
const bool add_bos = llama_add_bos_token(model);
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}

auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);

const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}

prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}

prompt_tokens = embd_inp;
} else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}

// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
switch (slot.cmpl_type) {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}

// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;

if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}

embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));

prompt_tokens = std::move(embd_inp);
} break;
}

slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();

SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

// print prompt tokens:
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}

// empty prompt passed -> release the slot and send empty response
if (prompt_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -2924,7 +2919,23 @@ int main(int argc, char ** argv) {
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
};

const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. ";
}
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "suffix token is missing. ";
}
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}

if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return;
}

json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
};
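The same capability check that `handle_infill` now performs can be done by any client of the C API against a loaded model. The helper below is an illustrative sketch (the name `model_supports_infill` is not from this PR); it relies only on `llama_token_fim_pre/suf/mid` and `LLAMA_TOKEN_NULL`, which this change set uses.

```cpp
// Sketch: report whether a loaded model exposes the FIM special tokens that
// the /infill endpoint now requires. model_supports_infill is illustrative,
// not an API introduced by this PR.
#include <cstdio>

#include "llama.h"

static bool model_supports_infill(const llama_model * model) {
    const bool has_pre = llama_token_fim_pre(model) != LLAMA_TOKEN_NULL;
    const bool has_suf = llama_token_fim_suf(model) != LLAMA_TOKEN_NULL;
    const bool has_mid = llama_token_fim_mid(model) != LLAMA_TOKEN_NULL;

    if (!has_pre) { fprintf(stderr, "prefix (FIM_PRE) token is missing\n"); }
    if (!has_suf) { fprintf(stderr, "suffix (FIM_SUF) token is missing\n"); }
    if (!has_mid) { fprintf(stderr, "middle (FIM_MID) token is missing\n"); }

    return has_pre && has_suf && has_mid;
}
```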
@@ -3010,7 +3021,8 @@ int main(int argc, char ** argv) {
if (body.count("content") != 0) {
const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false);
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);

std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);

if (with_pieces) {
for (const auto& token : tokens) {