
server : allow to specify custom prompt for penalty calculation #3727


Merged (1 commit, Dec 23, 2023)
8 changes: 5 additions & 3 deletions common/sampling.cpp
@@ -193,12 +193,14 @@ llama_token llama_sampling_sample(
}

// apply penalties
- if (!prev.empty()) {
+ const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+ const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+ if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

llama_sample_repetition_penalties(ctx_main, &cur_p,
- prev.data() + prev.size() - penalty_last_n,
- penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+ penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+ penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
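The selection logic in the hunk above can be restated outside of llama.cpp as a minimal standalone sketch: the penalty source is the custom token list when the flag is set, otherwise the recent context `prev`, and only its last `penalty_last_n` entries are penalized. The `penalty_window` helper and `token` alias below are illustrative names, not part of the patch, and a non-negative `penalty_last_n` is assumed.

```cpp
#include <algorithm>
#include <vector>

using token = int; // stand-in for llama_token

// Return the tokens that repetition penalties should be computed over:
// the custom penalty prompt when enabled, otherwise the recent context,
// truncated to the last `penalty_last_n` entries.
std::vector<token> penalty_window(const std::vector<token> & prev,
                                  const std::vector<token> & custom,
                                  bool use_custom,
                                  int penalty_last_n) {
    const std::vector<token> & src = use_custom ? custom : prev;
    const int used = std::min((int) src.size(), penalty_last_n);
    return std::vector<token>(src.end() - used, src.end()); // the `used` most recent tokens
}
```

Clamping with `std::min` keeps the window valid even when the chosen list is shorter than `penalty_last_n`, which also covers the empty-list case that previously short-circuited via `!prev.empty()`.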
3 changes: 3 additions & 0 deletions common/sampling.h
@@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance

std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;
Member:
Can we avoid this bool use_penalty_prompt_tokens flag?
It seems it can be replaced with !penalty_prompt_tokens.empty()

Contributor Author:
Not quite: the two params are independent. When penalty_prompt_tokens is empty, use_penalty_prompt_tokens can still be true, which means the server must use the custom penalty prompt even though it is empty at the start of the request. Otherwise there would be no way to distinguish "do not use a custom prompt" from "use an empty custom prompt".

} llama_sampling_params;

// general sampler context
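To make the distinction from the exchange above concrete, here is a minimal standalone sketch; `penalty_config`, `tok`, and `tokens_to_penalize` are illustrative names, not part of the patch. If the flag were replaced by `!penalty_prompt_tokens.empty()`, state B below would collapse into state A.

```cpp
#include <cstdio>
#include <vector>

using tok = int; // stand-in for llama_token

struct penalty_config {
    std::vector<tok> penalty_prompt_tokens;
    bool use_penalty_prompt_tokens = false;
};

// A: flag off                -> penalize against the normal context (prev)
// B: flag on, list empty     -> penalize only tokens generated in this request
// C: flag on, list non-empty -> penalize against the supplied tokens
const std::vector<tok> & tokens_to_penalize(const penalty_config & cfg,
                                            const std::vector<tok> & prev) {
    return cfg.use_penalty_prompt_tokens ? cfg.penalty_prompt_tokens : prev;
}

int main() {
    const std::vector<tok> prev = {1, 2, 3};

    penalty_config a;                                     // state A
    penalty_config b; b.use_penalty_prompt_tokens = true; // state B
    penalty_config c = {{7, 8}, true};                    // state C

    std::printf("A:%zu B:%zu C:%zu\n",
                tokens_to_penalize(a, prev).size(),  // 3 (uses prev)
                tokens_to_penalize(b, prev).size(),  // 0 (empty custom list)
                tokens_to_penalize(c, prev).size()); // 2 (custom list)
}
```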
2 changes: 2 additions & 0 deletions examples/server/README.md
@@ -148,6 +148,8 @@ node index.js

`frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);

`penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens (default: `null` = use the original `prompt`).

`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).

`mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
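As a usage sketch (not part of the PR's diff), a client can pass `penalty_prompt` either as a string, which the server tokenizes, or as an array of token ids; out-of-vocabulary ids are skipped, as the parsing code below shows. The bodies here are built with nlohmann/json, which the server itself uses; the prompt text, token ids, and penalty value are arbitrary examples, and the request is assumed to target the server's completion endpoint.

```cpp
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    using json = nlohmann::json;

    // penalty_prompt as a string: it is tokenized server-side and replaces
    // `prompt` for repetition-penalty purposes only.
    json body_string = {
        {"prompt",         "Translate the sentence into French:"},
        {"penalty_prompt", "the the the"},
        {"repeat_penalty", 1.2},
    };

    // penalty_prompt as an array of token ids: ids outside the vocabulary
    // are silently dropped by the server.
    json body_tokens = {
        {"prompt",         "Translate the sentence into French:"},
        {"penalty_prompt", json::array({15043, 3186, 29991})},
        {"repeat_penalty", 1.2},
    };

    std::cout << body_string.dump(2) << "\n" << body_tokens.dump(2) << "\n";
}
```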
44 changes: 44 additions & 0 deletions examples/server/server.cpp
@@ -761,6 +761,42 @@ struct llama_server_context
slot->prompt = "";
}

slot->sparams.penalty_prompt_tokens.clear();
slot->sparams.use_penalty_prompt_tokens = false;
const auto &penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end())
{
if (penalty_prompt->is_string())
{
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
if (slot->params.n_predict > 0)
{
slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
}
slot->sparams.use_penalty_prompt_tokens = true;
}
else if (penalty_prompt->is_array())
{
const auto n_tokens = penalty_prompt->size();
slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
const int n_vocab = llama_n_vocab(model);
for (const auto &penalty_token : *penalty_prompt)
{
if (penalty_token.is_number_integer())
{
const auto tok = penalty_token.get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{
slot->sparams.penalty_prompt_tokens.push_back(tok);
}
}
}
slot->sparams.use_penalty_prompt_tokens = true;
}
}

slot->sparams.logit_bias.clear();

if (json_value(data, "ignore_eos", false))
@@ -992,6 +1028,12 @@ struct llama_server_context
slot.generated_text += token_str;
slot.has_next_token = true;

if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
{
// we can change penalty_prompt_tokens because it is always created from scratch for each request
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
}

// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
@@ -1183,6 +1225,8 @@ struct llama_server_context
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},