Skip to content

feat: rendering chat_template #1814

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions engine/cli/commands/chat_completion_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {

return data_length;
}

} // namespace

void ChatCompletionCmd::Exec(const std::string& host, int port,
Expand Down Expand Up @@ -103,7 +102,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
return;
}

std::string url = "http://" + address + "/v1/chat/completions";
auto url = "http://" + address + "/v1/chat/completions";
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);

Expand Down Expand Up @@ -151,18 +150,18 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
json_data["model"] = model_handle;
json_data["stream"] = true;

std::string json_payload = json_data.toStyledString();

curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
auto json_str = json_data.toStyledString();
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_str.length());
curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L);

std::string ai_chat;
StreamingCallback callback;
callback.ai_chat = &ai_chat;

curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback);

CURLcode res = curl_easy_perform(curl);
auto res = curl_easy_perform(curl);

if (res != CURLE_OK) {
CLI_LOG("CURL request failed: " << curl_easy_strerror(res));
Expand Down
29 changes: 29 additions & 0 deletions engine/common/model_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>

#include "common/tokenizer.h"

// Aggregated metadata parsed from a model file (e.g. GGUF header counters)
// together with the tokenizer configuration extracted alongside it.
struct ModelMetadata {
  uint32_t version;            // model file format version
  uint64_t tensor_count;       // number of tensors declared in the file
  uint64_t metadata_kv_count;  // number of metadata key/value pairs
  // Parsed tokenizer config; may be null when none was found in the file.
  std::shared_ptr<Tokenizer> tokenizer;

  // Renders a human-readable dump of all fields for logging/debugging.
  // Prints "null" for the tokenizer slot when no tokenizer is attached.
  std::string ToString() const {
    std::string repr = "ModelMetadata {\n";
    repr += "version: " + std::to_string(version) + "\n";
    repr += "tensor_count: " + std::to_string(tensor_count) + "\n";
    repr += "metadata_kv_count: " + std::to_string(metadata_kv_count) + "\n";
    repr += "tokenizer: ";
    if (tokenizer) {
      repr += "\n" + tokenizer->ToString();
    } else {
      repr += "null";
    }
    repr += "\n}";
    return repr;
  }
};
72 changes: 72 additions & 0 deletions engine/common/tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include <sstream>
#include <string>

// Abstract base for per-format tokenizer configuration (GGUF, safetensors).
// Holds the special tokens and chat-template settings needed to render a
// prompt from chat messages.
struct Tokenizer {
  std::string eos_token = "";  // end-of-sequence token text
  bool add_eos_token = true;   // append EOS when tokenizing

  std::string bos_token = "";  // beginning-of-sequence token text
  bool add_bos_token = true;   // prepend BOS when tokenizing

  std::string unknown_token = "";  // token used for out-of-vocabulary input
  std::string padding_token = "";  // token used to pad sequences

  std::string chat_template = "";  // Jinja chat template, empty if absent

  bool add_generation_prompt = true;  // ask the template to open an assistant turn

  // Renders the fields shared by all tokenizer kinds, one "key: value" per
  // line. String values are wrapped in double quotes; booleans are not.
  // No trailing newline — callers append their own separator.
  std::string BaseToString() const {
    std::ostringstream ss;
    ss << "eos_token: \"" << eos_token << "\"\n"
       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
       << "bos_token: \"" << bos_token << "\"\n"
       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
       << "unknown_token: \"" << unknown_token << "\"\n"
       << "padding_token: \"" << padding_token << "\"\n"
       << "chat_template: \"" << chat_template << "\"\n"
       << "add_generation_prompt: "
       // Bug fix: the original appended << "\"" here, emitting a stray
       // double-quote after the boolean (other bool fields have none).
       << (add_generation_prompt ? "true" : "false");
    return ss.str();
  }

  virtual ~Tokenizer() = default;

  // Full human-readable dump; implemented by each concrete tokenizer kind.
  virtual std::string ToString() = 0;
};

// Tokenizer configuration parsed from a GGUF model file.
struct GgufTokenizer : public Tokenizer {
  // GGUF "tokenizer.ggml.pre" value: the pre-tokenizer scheme name.
  std::string pre = "";

  ~GgufTokenizer() override = default;

  // Dumps the shared base fields followed by the GGUF-specific ones.
  std::string ToString() override {
    std::string repr = "GgufTokenizer {\n";
    repr += BaseToString() + "\n";
    repr += "pre: \"" + pre + "\"\n";
    repr += "}";
    return repr;
  }
};

struct SafeTensorTokenizer : public Tokenizer {
bool add_prefix_space = true;

~SafeTensorTokenizer() = default;

std::string ToString() override {
std::ostringstream ss;
ss << "SafeTensorTokenizer {\n";
// Add base class members
ss << BaseToString() << "\n";
// Add derived class members
ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
ss << "}";
return ss.str();
}
};
17 changes: 5 additions & 12 deletions engine/controllers/files.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,8 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp =
cortex_utils::CreateCortexContentResponse(std::move(res.value()));
callback(resp);
} else {
if (!msg_res->rel_path.has_value()) {
Expand All @@ -243,10 +241,8 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(content_res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp = cortex_utils::CreateCortexContentResponse(
std::move(content_res.value()));
callback(resp);
}
}
Expand All @@ -261,9 +257,6 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value()));
callback(resp);
}
9 changes: 8 additions & 1 deletion engine/controllers/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/function_calling/common.h"
#include "utils/http_util.h"

using namespace inferences;

Expand All @@ -27,6 +26,14 @@ void server::ChatCompletion(
std::function<void(const HttpResponsePtr&)>&& callback) {
LOG_DEBUG << "Start chat completion";
auto json_body = req->getJsonObject();
if (json_body == nullptr) {
Json::Value ret;
ret["message"] = "Body can't be empty";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}
bool is_stream = (*json_body).get("stream", false).asBool();
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
Expand Down
1 change: 1 addition & 0 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
auto model_src_svc = std::make_shared<services::ModelSourceService>();
auto model_service = std::make_shared<ModelService>(
download_service, inference_svc, engine_service);
inference_svc->SetModelService(model_service);

auto file_watcher_srv = std::make_shared<FileWatcherService>(
model_dir_path.string(), model_service);
Expand Down
20 changes: 8 additions & 12 deletions engine/services/engine_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

Expand All @@ -17,7 +16,6 @@
#include "utils/cpuid/cpu_info.h"
#include "utils/dylib.h"
#include "utils/dylib_path_manager.h"
#include "utils/engine_constants.h"
#include "utils/github_release_utils.h"
#include "utils/result.hpp"
#include "utils/system_info_utils.h"
Expand Down Expand Up @@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
struct EngineInfo {
std::unique_ptr<cortex_cpp::dylib> dl;
EngineV engine;
#if defined(_WIN32)
DLL_DIRECTORY_COOKIE cookie;
DLL_DIRECTORY_COOKIE cuda_cookie;
#endif
};

std::mutex engines_mutex_;
Expand Down Expand Up @@ -105,21 +99,23 @@ class EngineService : public EngineServiceI {

cpp::result<DefaultEngineVariant, std::string> SetDefaultEngineVariant(
const std::string& engine, const std::string& version,
const std::string& variant);
const std::string& variant) override;

cpp::result<DefaultEngineVariant, std::string> GetDefaultEngineVariant(
const std::string& engine);
const std::string& engine) override;

cpp::result<std::vector<EngineVariantResponse>, std::string>
GetInstalledEngineVariants(const std::string& engine) const;
GetInstalledEngineVariants(const std::string& engine) const override;

cpp::result<EngineV, std::string> GetLoadedEngine(
const std::string& engine_name);

std::vector<EngineV> GetLoadedEngines();

cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
cpp::result<void, std::string> LoadEngine(
const std::string& engine_name) override;
cpp::result<void, std::string> UnloadEngine(
const std::string& engine_name) override;

cpp::result<github_release_utils::GitHubRelease, std::string>
GetLatestEngineVersion(const std::string& engine) const;
Expand All @@ -137,7 +133,7 @@ class EngineService : public EngineServiceI {

cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
const std::string& engine_name,
const std::optional<std::string> variant = std::nullopt);
const std::optional<std::string> variant = std::nullopt) override;

cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
const std::string& engine_name, const std::string& type,
Expand Down
42 changes: 41 additions & 1 deletion engine/services/inference_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <drogon/HttpTypes.h>
#include "utils/engine_constants.h"
#include "utils/function_calling/common.h"
#include "utils/jinja_utils.h"

namespace services {
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
Expand All @@ -24,6 +25,45 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
return cpp::fail(std::make_pair(stt, res));
}

{
auto model_id = json_body->get("model", "").asString();
if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json,
tokenizer->bos_token, tokenizer->eos_token,
tokenizer->add_bos_token, tokenizer->add_eos_token,
tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
}
}
}
}

CTL_INF("Json body inference: " + json_body->toStyledString());

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
Expand Down Expand Up @@ -297,4 +337,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
}
return true;
}
} // namespace services
} // namespace services
9 changes: 8 additions & 1 deletion engine/services/inference_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#include <mutex>
#include <queue>
#include "services/engine_service.h"
#include "services/model_service.h"
#include "utils/result.hpp"
#include "extensions/remote-engine/remote_engine.h"

namespace services {

// Status and result
using InferResult = std::pair<Json::Value, Json::Value>;

Expand Down Expand Up @@ -58,7 +60,12 @@ class InferenceService {
bool HasFieldInReq(std::shared_ptr<Json::Value> json_body,
const std::string& field);

void SetModelService(std::shared_ptr<ModelService> model_service) {
model_service_ = model_service;
}

private:
std::shared_ptr<EngineService> engine_service_;
std::weak_ptr<ModelService> model_service_;
};
} // namespace services
Loading