server: add option to output probabilities for completion #1962

Merged
merged 27 commits into from Jul 2, 2023
Changes from 20 commits

Commits (27)
ba210e4
server: add option to output probabilities for completion
WangHaoranRobin Jun 21, 2023
8004e67
Merge pull request #1 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 21, 2023
ccf254b
server: fix comment about max n_probs
WangHaoranRobin Jun 22, 2023
926664c
Merge pull request #2 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 22, 2023
cf76195
server: fix issue when handling probability output for incomplete tok…
WangHaoranRobin Jun 23, 2023
bdb710e
Merge pull request #3 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 23, 2023
7b93b24
server: fix some beginner mistakes
WangHaoranRobin Jun 23, 2023
7cd8fc2
Merge pull request #4 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 23, 2023
6c76c31
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 23, 2023
02c96a4
server: remove trailling white space
WangHaoranRobin Jun 24, 2023
7f7046e
Merge pull request #5 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 24, 2023
23b516b
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 24, 2023
af058cf
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 25, 2023
e815b69
server: remove n_probs upper limit of 5
WangHaoranRobin Jun 25, 2023
bd6550b
Merge pull request #6 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 25, 2023
c9e6642
server: handle probs output when temp=0; handle final response probs …
WangHaoranRobin Jun 25, 2023
13f5d69
Merge branch 'master' into robin_fork_master
WangHaoranRobin Jun 25, 2023
77edee7
Merge pull request #7 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 25, 2023
b5c5c8e
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 26, 2023
c7f7f13
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 27, 2023
bc88fec
server: fix llama_sample_top_k order
WangHaoranRobin Jun 27, 2023
58828c2
Merge pull request #8 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 27, 2023
1d22550
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 28, 2023
ad80773
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jul 1, 2023
1a70a80
examples/common.h: put all bool variables in gpt_params together
WangHaoranRobin Jul 2, 2023
71f8296
examples/common.h: put all bool variables in gpt_params together
WangHaoranRobin Jul 2, 2023
cc3c86f
Merge pull request #9 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jul 2, 2023
1 change: 1 addition & 0 deletions examples/common.h
@@ -32,6 +32,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.

// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
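
For context, this is the only change to gpt_params: the new n_probs field defaults to 0, so probability output stays off unless a request asks for it; parse_options_completion in server.cpp below reads the matching "n_probs" key from the request body.
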
150 changes: 120 additions & 30 deletions examples/server/server.cpp
@@ -26,6 +26,17 @@ struct server_params {
int32_t write_timeout = 600;
};

// completion token output with probabilities
struct completion_token_output {
struct token_prob {
llama_token tok;
float prob;
};

std::vector<token_prob> probs;
llama_token tok;
};

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
fflush(stdout);
}

// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
// if the high bit of the first byte is set, the token text starts inside a multibyte UTF-8 character
if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
std::stringstream ss;
ss << std::hex << (out[0] & 0xff);
std::string res(ss.str());
out = "byte: \\x" + res;
}
return out;
}
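
As a side note, the check above only detects that the token text begins partway through a multi-byte UTF-8 character and reports its first byte as hex. A minimal standalone sketch of the same logic, with made-up token strings and no llama.cpp dependency, would look roughly like this:

// Illustrative sketch of the partial-UTF-8 handling above; the inputs are
// hypothetical and llama_token_to_str is not involved.
#include <cstdio>
#include <sstream>
#include <string>

static std::string format_token_text(const std::string & out) {
    // a leading byte with the high bit set means the text starts inside a
    // multi-byte UTF-8 character, so report the raw byte as hex instead
    if (!out.empty() && (out[0] & 0x80) == 0x80) {
        std::stringstream ss;
        ss << std::hex << (out[0] & 0xff);
        return "byte: \\x" + ss.str();
    }
    return out;
}

int main() {
    printf("%s\n", format_token_text("Hello").c_str());    // Hello
    printf("%s\n", format_token_text("\xe4\xbd").c_str());  // byte: \xe4
}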

// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
json out = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json {
{ "tok_str", tok_str },
{ "prob", p.prob },
});
}
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str},
{"probs", probs_for_token},
});
}
return out;
}
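
For readers skimming the diff, here is a hedged sketch of the array this helper produces, built directly with nlohmann::json and entirely hypothetical token strings and probabilities, so the output shape is visible without running the server:

// Illustrative only: mirrors the shape emitted by probs_vector_to_json.
#include <cstdio>
#include "json.hpp" // nlohmann::json, already vendored by the server example

using json = nlohmann::json;

int main() {
    json out = json::array();
    out.push_back(json {
        { "content", "Hello" },
        { "probs", json::array({
            json { { "tok_str", "Hello" }, { "prob", 0.81 } },
            json { { "tok_str", "Hi" },    { "prob", 0.10 } },
        }) },
    });
    printf("%s\n", out.dump(2).c_str());
}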

static bool server_verbose = false;

#if SERVER_VERBOSE != 1
@@ -107,6 +152,7 @@ struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_token_output> generated_token_probs;

size_t num_tokens_predicted = 0;
size_t n_past = 0;
@@ -142,6 +188,7 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -221,8 +268,9 @@ struct llama_server_context {
llama_set_rng_seed(ctx, params.seed);
}

llama_token nextToken() {
llama_token result = -1;
completion_token_output nextToken() {
completion_token_output result;
result.tok = -1;

if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
@@ -261,7 +309,8 @@ struct llama_server_context {

if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
result.tok = llama_token_eos();
return result;
}

// out of user input, sample next token
@@ -278,7 +327,7 @@ struct llama_server_context {
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
const int32_t n_probs = params.n_probs;

{
auto * logits = llama_get_logits(ctx);
@@ -312,35 +361,42 @@

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
result.tok = llama_sample_token_greedy(ctx, &candidates_p);
if (n_probs > 0) {
llama_sample_softmax(ctx, &candidates_p);
}
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
size_t min_keep = std::max(1, n_probs);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}

for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
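
Note that the filters above are now called with min_keep = max(1, n_probs), so they can never trim the candidate list below the number of probabilities the caller asked for, and the greedy temp <= 0 path runs an explicit softmax so that candidate probabilities exist to report at all. A rough standalone sketch of why min_keep matters, using a hand-rolled top-p cut-off and made-up probabilities rather than the real llama.cpp samplers:

// Illustrative only: shows how a min_keep floor preserves enough candidates
// for the top-n probability report after nucleus (top-p) filtering.
#include <algorithm>
#include <cstdio>
#include <vector>

struct candidate { int id; float p; };

int main() {
    // already sorted and softmaxed, as candidates_p is once the samplers run
    std::vector<candidate> candidates = { {1, 0.70f}, {2, 0.20f}, {3, 0.06f}, {4, 0.04f} };
    const int n_probs = 3;
    const size_t min_keep = std::max<size_t>(1, n_probs);

    const float top_p = 0.75f;
    float cum = 0.0f;
    size_t keep = candidates.size();
    for (size_t i = 0; i < candidates.size(); i++) {
        cum += candidates[i].p;
        if (cum >= top_p && i + 1 >= min_keep) { keep = i + 1; break; }
    }
    candidates.resize(keep);

    // with min_keep = 1 only two candidates would survive the 0.75 cut-off;
    // with min_keep = 3 there are still enough entries to fill the report
    for (size_t i = 0; i < candidates.size() && i < (size_t) n_probs; i++) {
        printf("token %d: %.2f\n", candidates[i].id, candidates[i].p);
    }
}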

@@ -382,12 +438,16 @@ struct llama_server_context {
return stop_pos;
}

std::string doCompletion() {
const llama_token token = nextToken();
completion_token_output doCompletion() {
const completion_token_output token_with_probs = nextToken();

const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
generated_text += token_text;

if (params.n_probs > 0) {
generated_token_probs.push_back(token_with_probs);
}

if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
@@ -416,8 +476,8 @@ struct llama_server_context {
}

LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "token", token_with_probs.tok },
{ "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
Expand All @@ -427,7 +487,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word },
});

return token_text;
return token_with_probs;
}

std::vector<float> getEmbedding() {
@@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
{ "ignore_eos", ignore_eos },
{ "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
};
}

@@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
};
}

static json format_final_response(llama_server_context & llama, const std::string & content) {
return json {
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {

json res = json {
{ "content", content },
{ "stop", true },
{ "model", llama.params.model_alias },
@@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word },
};

if (llama.params.n_probs > 0) {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}

return res;
}

static json format_partial_response(const std::string & content) {
return json {
static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
json res = json {
{ "content", content },
{ "stop", false },
};

if (llama.params.n_probs > 0) {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}

return res;
}

static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) {
@@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);

stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
@@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}

const json data = format_final_response(llama, llama.generated_text);
const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);

llama_print_timings(llama.ctx);

@@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
} else {
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
if (llama.multibyte_pending > 0) {
continue;
}
@@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
const std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size();

std::vector<completion_token_output> probs_output = {};

if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}

const json data = llama.has_next_token
? format_partial_response(to_send)
? format_partial_response(llama, to_send, probs_output)
// Generation is done, send extra information.
: format_final_response(llama, to_send);
: format_final_response(llama, to_send, llama.generated_token_probs);

const std::string str =
"data: " +
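
Taken together, the feature is driven entirely by the request body: a completion request that sets "n_probs" (for example { "prompt": "Hello", "n_predict": 4, "n_probs": 2 }, values purely illustrative) gets back a "completion_probabilities" array, attached to each streamed chunk for the tokens in that chunk and to the final response for the whole generation. The per-chunk slice is derived by re-tokenizing the text being sent, so its token count can drift from the original boundaries when stopping strings or pending multibyte characters split the output.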