server: fix issue for multibyte character generation #3

Merged 1 commit on Jun 23, 2023
135 changes: 73 additions & 62 deletions examples/server/server.cpp
@@ -26,17 +26,6 @@ struct server_params {
int32_t write_timeout = 600;
};

// completion string output with probabilities
struct completion_string_output {
struct token_prob {
std::string tok_str;
float prob;
};

std::vector<token_prob> probs;
std::string tok_str;
};

// completion token output with probabilities
struct completion_token_output {
struct token_prob {
@@ -108,6 +97,36 @@ static void server_log(const char * level, const char * function, int line,
fflush(stdout);
}

// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
// a piece of size 1 with the high bit set is an incomplete multibyte UTF-8 character;
// emit its hex value instead of raw bytes that would produce invalid JSON
// (std::stringstream needs <sstream> to be included)
if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
std::stringstream ss;
ss << std::hex << (out[0] & 0xff);
out = "byte: \\x" + ss.str();
}
return out;
}
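// Illustrative example (assumed values): a token whose piece is the single lead byte 0xE6 of a
// three-byte UTF-8 character is formatted as "byte: \xe6", while tokens that decode to complete
// characters are returned unchanged.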

// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
json out = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json {
{ "tok_str", tok_str },
{ "prob", p.prob },
});
}
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str},
{"probs", probs_for_token},
});
}
return out;
}
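// Sketch of the resulting JSON shape, with assumed example values:
// [
//   { "content": "Hello", "probs": [ { "tok_str": "Hello", "prob": 0.91 },
//                                    { "tok_str": "Hi",    "prob": 0.07 } ] },
//   ...
// ]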

static bool server_verbose = false;

#if SERVER_VERBOSE != 1
@@ -129,7 +148,7 @@ struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_string_output> generated_text_probs;
std::vector<completion_token_output> generated_token_probs;

size_t num_tokens_predicted = 0;
size_t n_past = 0;
@@ -160,7 +179,7 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text_probs.clear();
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -406,22 +425,16 @@ struct llama_server_context {
return stop_pos;
}

completion_string_output doCompletion() {
completion_token_output doCompletion() {
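// draws the next token via nextToken(), appends the decoded text to generated_text and, when
// params.n_probs > 0, also records the token together with its candidate probabilities in
// generated_token_probs for later reporting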
const completion_token_output token_with_probs = nextToken();
completion_string_output result;

const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
result.tok_str = token_text;
generated_text += token_text;

// iterate through token_with_probs.probs, if tok is valid, convert it to string and add to result.prob
for (const auto & prob : token_with_probs.probs) {
const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok);
result.probs.push_back({prob_text, prob.prob});
if (params.n_probs > 0) {
generated_token_probs.push_back(token_with_probs);
}

generated_text_probs.push_back(result);

if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
@@ -451,7 +464,7 @@ struct llama_server_context {

LOG_VERBOSE("next token", {
{ "token", token_with_probs.tok },
{ "token_text", llama_token_to_str(ctx, token_with_probs.tok) },
{ "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
@@ -461,7 +474,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word },
});

return result;
return token_with_probs;
}

std::vector<float> getEmbedding() {
@@ -713,26 +726,10 @@ static json format_embedding_response(llama_server_context & llama) {
};
}

static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_string_output> & probs) {
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {

json completion_probabilities_json = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
probs_for_token.push_back(json {
{ "tok_str", p.tok_str },
{ "prob", p.prob },
});
}
completion_probabilities_json.push_back(json {
{"content", prob.tok_str},
{"probs", probs_for_token},
});
}

return json {
json res = json {
{ "content", content },
{ "completion_probabilities", completion_probabilities_json},
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
@@ -743,25 +740,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "stopped_word", llama.stopped_word },
{ "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word },
};
}

if (llama.params.n_probs > 0) {
json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
res["completion_probabilities"] = completion_probabilities_json;
}

return res;
}

static json format_partial_response(const std::string & content, const completion_string_output & probs) {
static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
json res = json {
{ "content", content },
{ "stop", false },
};

// iterate through probs.probs, and add to res
json probs_json = json::array();
for (const auto & prob : probs.probs) {
probs_json.push_back(json {
{ "tok_str", prob.tok_str },
{ "prob", prob.prob },
});
}
if (probs.probs.size() > 0) {
res["probs"] = probs_json;
if (llama.params.n_probs > 0) {
json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
res["completion_probabilities"] = completion_probabilities_json;
}

return res;
@@ -897,8 +894,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos;

while (llama.has_next_token) {
const completion_string_output token_text_with_probs = llama.doCompletion();
const std::string token_text = token_text_with_probs.tok_str;
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);

stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
@@ -912,7 +909,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}

const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs);
const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);

llama_print_timings(llama.ctx);

@@ -921,9 +918,11 @@ int main(int argc, char ** argv) {
} else {
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;

while (llama.has_next_token) {
const completion_string_output token_text_with_probs = llama.doCompletion();
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
if (llama.multibyte_pending > 0) {
continue;
}
@@ -932,24 +931,36 @@ int main(int argc, char ** argv) {

const std::string str_test = llama.generated_text.substr(pos);
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL);
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(),
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}

const std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size();

std::vector<completion_token_output> probs_output = {};

if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
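// the chunk being sent is re-tokenized so the matching slice of generated_token_probs can be
// attached to it; sent_token_probs_index tracks how many per-token probability entries have
// already been streamed, so each entry is sent exactly once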

const json data = llama.has_next_token
? format_partial_response(to_send, token_text_with_probs)
? format_partial_response(llama, to_send, probs_output)
// Generation is done, send extra information.
: format_final_response(llama, to_send, {token_text_with_probs});
: format_final_response(llama, to_send, probs_output);

const std::string str =
"data: " +