
Commit e6de69a

Merge pull request #3 from anon998/sse
Add streaming via server-sent events. It has some changes that I didn't make, and I decided I prefer "stream" to "streaming".
2 parents: a25f830 + 2533878

1 file changed: examples/server/server.cpp (70 additions, 94 deletions)

@@ -13,7 +13,7 @@ struct server_params
 
 struct llama_server_context
 {
-    bool streaming = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
 
@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;
 
     void rewind() {
-        streaming = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(streaming) {
-            generated_text = "";
-        }
 
         std::string token_text = llama_token_to_str(ctx, token);
         generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
             }
         }
 
-        return generated_text;
+        return token_text;
     }
 
     std::vector<float> embedding(std::string content, int threads) {
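Note: after this hunk, doCompletion() returns only the text of the newly sampled token, while generated_text keeps accumulating the full output (see the generated_text += token_text; context line above). A minimal caller sketch, for illustration only and assuming it runs inside server.cpp where a llama_server_context named llama exists:

    // illustration: print each token as it is produced, then use the accumulated text
    llama.beginCompletion();
    while (llama.has_next_token) {
        std::string token_text = llama.doCompletion(); // just the new token
        fputs(token_text.c_str(), stdout);
    }
    // llama.generated_text now holds the concatenation of all emitted tokens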
@@ -478,13 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 
 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
     gpt_params default_params;
-    if (!body["streaming"].is_null())
+    if (!body["stream"].is_null())
     {
-        llama.streaming = body["streaming"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
     }
     else
     {
-        llama.streaming = false;
+        llama.stream = false;
     }
     if (!body["n_predict"].is_null())
     {
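With this change a client opts into streaming through the request body rather than a server-side flag. An illustrative body (only "stream" and "n_predict" appear in this hunk; the "prompt" key and all values are made up):

    { "prompt": "Hello", "n_predict": 16, "stream": true }

When "stream" is absent it defaults to false, so existing non-streaming clients keep working unchanged.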
@@ -675,8 +671,6 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";
 
-    std::string final_text;
-
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
@@ -693,98 +687,81 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
 
-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-             {
-        if(llama.params.embedding) {
-            json data = {
-                {"status", "error"},
-                {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
+        if (llama.params.embedding) {
+            json data = {
+                {"status", "error"},
+                {"reason", "To use completion function, disable embedding mode"}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
 
-        llama.rewind();
-        final_text = "";
+        llama.rewind();
 
-        if(parse_options_completion(json::parse(req.body), llama, res) == false){
-            return;
-        }
+        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
+            return;
+        }
 
-        if (!llama.loadPrompt())
-        {
-            json data = {
-                {"status", "error"},
-                {"reason", "Context too long."}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
+
+        llama.beginCompletion();
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
+                llama.doCompletion();
+            }
+
+            json data = {{"content", llama.generated_text},
+                         {"stop", true},
+                         {"model", llama.params.model_alias },
+                         {"tokens_predicted", llama.num_tokens_predicted},
+                         {"generation_settings", format_generation_settings(llama)},
+                         {"prompt", llama.params.prompt},
+                         {"stopping_word", llama.stopping_word}};
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
 
-        llama.beginCompletion();
-        if(llama.streaming)
-        {
-            res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/,
-                DataSink& sink) {
-                std::string final_text = "";
-                // loop inference until finish completion
-                while (llama.has_next_token) {
-                    std::string result = llama.doCompletion();
                     json data;
-                    final_text += result;
-                    if (llama.has_next_token)
-                    {
-                        data = { {"content", result}, {"stop", false} };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = { {"content", result},
-                            {"stop", true},
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text} };
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
                     }
 
                     std::string str =
-                        "data: " + data.dump(4, ' ', false, json::error_handler_t::replace) +
-                        "\n\n";
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
                     sink.write(str.data(), str.size());
-                }
-
-                sink.done();
-                return true;
-            });
                 }
-        else
-        {
-            // loop inference until finish completion
-            while (llama.has_next_token)
-            {
-                llama.doCompletion();
-            }
-            try
-            {
-                json data = {
-                    {"model", llama.params.model_alias },
-                    {"content", llama.generated_text },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word} };
-                return res.set_content(data.dump(), "application/json");
-            }
-            catch (const json::exception &e)
-            {
-                // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                json data = {
-                    {"content", "Bad encoding token"},
-                    {"tokens_predicted", 0}};
-                return res.set_content(data.dump(), "application/json");
-            }
-        } });
+
+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });
+
 
     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
     {
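When "stream" is true, the handler above answers with Content-Type text/event-stream and writes one server-sent event per generated token: a "data: " prefix, the compact JSON dump, then a blank line. The final event has "stop" set to true and carries the extra fields (model, tokens_predicted, generation_settings, prompt, stopping_word, generated_text). An illustrative, entirely made-up exchange, with generation_settings abbreviated to an empty object:

    data: {"content":" Hello","stop":false}

    data: {"content":"!","stop":true,"model":"ggml-model.bin","tokens_predicted":2,"generation_settings":{},"prompt":"Say hi","stopping_word":"","generated_text":" Hello!"}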
@@ -811,7 +788,6 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });
 
-
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
 
     if(params.embedding) {
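For manual testing, a rough client sketch that is not part of this commit, using the same headers the server example includes (httplib.h and json.hpp); the host, port, prompt text, and n_predict value are assumptions:

    // hypothetical test client for the /completion endpoint
    #include "httplib.h"
    #include "json.hpp"
    #include <iostream>

    using json = nlohmann::json;

    int main() {
        httplib::Client cli("localhost", 8080);   // assumed server address
        json body = {{"prompt", "Hello"}, {"n_predict", 16}, {"stream", false}};

        // with "stream": false the server replies with a single JSON object
        auto res = cli.Post("/completion", body.dump(), "application/json");
        if (res && res->status == 200) {
            json data = json::parse(res->body);
            std::cout << data["content"].get<std::string>() << std::endl;
        }
        return 0;
    }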
