Skip to content

Commit

Permalink
server: fix incorrectly reported token probabilities (#7125)
Browse files Browse the repository at this point in the history
* server: normalize token probabilities

* fix temperature == 0.0f
  • Loading branch information
JohannesGaessler authored May 7, 2024
1 parent b6aa670 commit af0a5b6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
5 changes: 5 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

result->prev.resize(params.n_prev);

result->n_considered = 0;

llama_sampling_set_rng_seed(result, params.seed);

return result;
Expand Down Expand Up @@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_considered = 0;
}

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
Expand Down Expand Up @@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
}
}

ctx_sampling->n_considered = cur_p.size;

return id;
}

Expand Down
1 change: 1 addition & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
size_t n_considered;

std::mt19937 rng;
};
Expand Down
2 changes: 1 addition & 1 deletion examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ node index.js

`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`

`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`

`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`

Expand Down
34 changes: 24 additions & 10 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2266,17 +2266,31 @@ struct server_context {
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;

const int32_t n_probs = slot.sparams.n_probs;
if (slot.sparams.temp <= 0 && n_probs > 0) {
// for llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &cur_p);
}
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
if (n_probs > 0) {
const size_t n_considered = slot.ctx_sampling->n_considered;

for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
result.probs.push_back({
cur_p.data[i].id,
cur_p.data[i].p
});
// Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
}

if (slot.sparams.temp == 0.0f) {
// With greedy sampling the probabilities have possibly not been calculated.
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i == 0 ? 1.0f : 0.0f
});
}
} else {
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
});
}
}
}

if (!process_token(result, slot)) {
Expand Down

0 comments on commit af0a5b6

Please sign in to comment.