
Server: Use multi-task for embeddings endpoint #6001


Merged · 3 commits · Mar 13, 2024
76 changes: 27 additions & 49 deletions examples/server/server.cpp
@@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods", "POST");
         res.set_header("Access-Control-Allow-Headers", "*");
+        return res.set_content("", "application/json; charset=utf-8");
     });
 
     svr->set_logger(log_server_request);
@@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool is_openai = false;
 
-        // an input prompt can string or a list of tokens (integer)
-        std::vector<json> prompts;
+        // an input prompt can be a string or a list of tokens (integer)
+        json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            if (body["input"].is_array()) {
-                // support multiple prompts
-                for (const json & elem : body["input"]) {
-                    prompts.push_back(elem);
-                }
-            } else {
-                // single input prompt
-                prompts.push_back(body["input"]);
-            }
+            prompt = body["input"];
         } else if (body.count("content") != 0) {
-            // only support single prompt here
-            std::string content = body["content"];
-            prompts.push_back(content);
+            // with "content", we only support single prompt
+            prompt = std::vector<std::string>{body["content"]};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        // process all prompts
-        json responses = json::array();
-        for (auto & prompt : prompts) {
-            // TODO @ngxson : maybe support multitask for this endpoint?
-            // create and queue the task
+        // create and queue the task
+        json responses;
+        {
             const int id_task = ctx_server.queue_tasks.get_new_id();
-
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
             if (!result.error) {
-                // append to the responses
-                responses.push_back(result.data);
+                if (result.data.count("results")) {
+                    // result for multi-task
+                    responses = result.data["results"];
+                } else {
+                    // result for single task
+                    responses = std::vector<json>{result.data};
+                }
             } else {
                 // error received, ignore everything else
                 res_error(res, result.data);
@@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root;
-        if (is_openai) {
-            json res_oai = json::array();
-            int i = 0;
-            for (auto & elem : responses) {
-                res_oai.push_back(json{
-                    {"embedding", json_value(elem, "embedding", json::array())},
-                    {"index", i++},
-                    {"object", "embedding"}
-                });
-            }
-            root = format_embeddings_response_oaicompat(body, res_oai);
-        } else {
-            root = responses[0];
-        }
+        json root = is_openai
+            ? format_embeddings_response_oaicompat(body, responses)
+            : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
 
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
     //
     // Router
     //
@@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
     }
 
     // using embedded static files
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
-        return res.set_content("", "application/json; charset=utf-8");
-    });
     svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
     svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
     svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
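
Note on the new behavior: the handler no longer loops over prompts itself. The parsed "prompt" (a string or an array) goes to request_completion as a single request, the task queue splits an array into subtasks, and their combined results come back under "results". Below is a minimal client sketch of the multi-prompt path; the host, port, and the /v1/embeddings route are assumptions (none of them appear in this diff), while the request and response shapes follow the code above.

// embeddings_client.cpp -- illustrative sketch, not part of the PR.
// Build against cpp-httplib and nlohmann/json, e.g.:
//   g++ -std=c++17 embeddings_client.cpp -o embeddings_client -pthread
#include <iostream>
#include "httplib.h"   // cpp-httplib, the same HTTP library the server example uses
#include "json.hpp"    // nlohmann::json

using json = nlohmann::json;

int main() {
    // Assumed server location: a llama.cpp server running with an embedding model.
    httplib::Client cli("localhost", 8080);

    // OpenAI-style body: "input" may be a single string or, with this PR,
    // an array of prompts that the server handles as one multi-task.
    const json body = {
        {"input", {"first prompt", "second prompt"}}
    };

    auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
    if (!res || res->status != 200) {
        std::cerr << "request failed\n";
        return 1;
    }

    // OpenAI "list"-shaped response: one {embedding, index, object} entry
    // per input prompt, in order.
    const json out = json::parse(res->body);
    for (const auto & item : out["data"]) {
        std::cout << "index " << item["index"]
                  << " -> " << item["embedding"].size() << " dims\n";
    }
    return 0;
}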
12 changes: 11 additions & 1 deletion examples/server/utils.hpp
@@ -529,14 +529,24 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
 }
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json data = json::array();
+    int i = 0;
+    for (auto & elem : embeddings) {
+        data.push_back(json{
+            {"embedding", json_value(elem, "embedding", json::array())},
+            {"index", i++},
+            {"object", "embedding"}
+        });
+    }
+
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
         {"usage", json {
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
-        {"data", embeddings}
+        {"data", data}
     };
 
     return res;
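
For reference, format_embeddings_response_oaicompat now builds the per-prompt "data" array itself, so the endpoint can pass the collected task results straight through. Below is a standalone sketch of the output shape; json_value and DEFAULT_OAICOMPAT_MODEL are stand-ins for the real helpers defined elsewhere in examples/server, and the sample values are illustrative.

// format_sketch.cpp -- illustrative sketch, not part of the PR.
#include <iostream>
#include <string>
#include "json.hpp"  // nlohmann::json

using json = nlohmann::json;

// Stand-ins for helpers defined elsewhere in examples/server:
#define DEFAULT_OAICOMPAT_MODEL "unknown"  // placeholder value
static json json_value(const json & body, const std::string & key, const json & def) {
    return body.contains(key) ? body.at(key) : def;
}

// Copy of the updated formatter from utils.hpp above:
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
    json data = json::array();
    int i = 0;
    for (auto & elem : embeddings) {
        data.push_back(json{
            {"embedding", json_value(elem, "embedding", json::array())},
            {"index", i++},
            {"object", "embedding"}
        });
    }

    json res = json {
        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
        {"object", "list"},
        {"usage", json {
            {"prompt_tokens", 0},
            {"total_tokens", 0}
        }},
        {"data", data}
    };

    return res;
}

int main() {
    // Two fake single-task results, as the server would collect them:
    const json results = json::array({
        json{{"embedding", {0.1, 0.2}}},
        json{{"embedding", {0.3, 0.4}}}
    });
    // Prints a "list" object whose "data" has entries at index 0 and 1.
    std::cout << format_embeddings_response_oaicompat(json::object(), results).dump(2) << std::endl;
    return 0;
}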