Skip to content

Commit 95cacf7

Browse files
committed
chat: try minja instead of Jinja2Cpp
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
1 parent 6a8a840 commit 95cacf7

File tree

13 files changed

+135
-273
lines changed

13 files changed

+135
-273
lines changed

.gitmodules

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
[submodule "gpt4all-chat/deps/QXlsx"]
1818
path = gpt4all-chat/deps/QXlsx
1919
url = https://github.com/nomic-ai/QXlsx.git
20-
[submodule "gpt4all-chat/deps/Jinja2Cpp"]
21-
path = gpt4all-chat/deps/Jinja2Cpp
22-
url = https://github.com/nomic-ai/jinja2cpp.git
23-
[submodule "gpt4all-chat/deps/rapidjson"]
24-
path = gpt4all-chat/deps/rapidjson
25-
url = https://github.com/nomic-ai/rapidjson.git
20+
[submodule "gpt4all-chat/deps/minja"]
21+
path = gpt4all-chat/deps/minja
22+
url = https://github.com/nomic-ai/minja.git
23+
[submodule "gpt4all-chat/deps/json"]
24+
path = gpt4all-chat/deps/json
25+
url = https://github.com/nlohmann/json.git

gpt4all-chat/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
5151
set(CMAKE_CXX_STANDARD 23)
5252
set(CMAKE_CXX_STANDARD_REQUIRED ON)
5353
if (MSVC)
54-
# Enable accurate __cplusplus macro to fix errors in Jinja2Cpp
54+
# Enable accurate __cplusplus macro
5555
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:__cplusplus>)
5656
endif()
5757

@@ -437,7 +437,10 @@ else()
437437
target_link_libraries(chat PRIVATE pdfium)
438438
endif()
439439
target_link_libraries(chat
440-
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx jinja2cpp)
440+
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx)
441+
target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/json/include)
442+
target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/json/include/nlohmann)
443+
target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/minja/include)
441444

442445
if (APPLE)
443446
target_link_libraries(chat PRIVATE ${COCOA_LIBRARY})

gpt4all-chat/deps/CMakeLists.txt

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,6 @@ add_subdirectory(DuckX)
1515
set(QT_VERSION_MAJOR 6)
1616
add_subdirectory(QXlsx/QXlsx)
1717

18-
# forked dependency of Jinja2Cpp
19-
set(RAPIDJSON_BUILD_DOC OFF)
20-
set(RAPIDJSON_BUILD_EXAMPLES OFF)
21-
set(RAPIDJSON_BUILD_TESTS OFF)
22-
set(RAPIDJSON_ENABLE_INSTRUMENTATION_OPT OFF)
23-
add_subdirectory(rapidjson)
24-
25-
add_subdirectory(Jinja2Cpp)
26-
2718
if (NOT GPT4ALL_USING_QTPDF)
2819
# If we do not use QtPDF, we need to get PDFium.
2920
set(GPT4ALL_PDFIUM_TAG "chromium/6954")

gpt4all-chat/deps/Jinja2Cpp

Lines changed: 0 additions & 1 deletion
This file was deleted.

gpt4all-chat/deps/json

Submodule json added at 606b634

gpt4all-chat/deps/minja

Submodule minja added at 491f5cb

gpt4all-chat/deps/rapidjson

Lines changed: 0 additions & 1 deletion
This file was deleted.

gpt4all-chat/src/chatllm.cpp

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@
1212
#include "toolcallparser.h"
1313

1414
#include <fmt/format.h>
15-
16-
#include <jinja2cpp/error_info.h>
17-
#include <jinja2cpp/template.h>
18-
#include <jinja2cpp/template_env.h>
19-
#include <jinja2cpp/user_callable.h>
20-
#include <jinja2cpp/value.h>
15+
#include <minja/minja.hpp>
16+
#include <nlohmann/json.hpp>
2117

2218
#include <QDataStream>
2319
#include <QDebug>
@@ -60,44 +56,31 @@
6056
using namespace Qt::Literals::StringLiterals;
6157
using namespace ToolEnums;
6258
namespace ranges = std::ranges;
59+
using json = nlohmann::ordered_json;
6360

6461
//#define DEBUG
6562
//#define DEBUG_MODEL_LOADING
6663

6764
// NOTE: not threadsafe
68-
static jinja2::TemplateEnv *jinjaEnv()
65+
static const std::shared_ptr<minja::Context> &jinjaEnv()
6966
{
70-
static std::optional<jinja2::TemplateEnv> environment;
67+
static std::shared_ptr<minja::Context> environment;
7168
if (!environment) {
72-
auto &env = environment.emplace();
73-
auto &settings = env.GetSettings();
74-
settings.trimBlocks = true;
75-
settings.lstripBlocks = true;
76-
env.AddGlobal("raise_exception", jinja2::UserCallable(
77-
/*callable*/ [](auto &params) -> jinja2::Value {
78-
auto messageArg = params.args.find("message");
79-
if (messageArg == params.args.end() || !messageArg->second.isString())
80-
throw std::runtime_error("'message' argument to raise_exception() must be a string");
81-
throw std::runtime_error(fmt::format("Jinja template error: {}", messageArg->second.asString()));
82-
},
83-
/*argsInfo*/ { jinja2::ArgInfo("message", /*isMandatory*/ true) }
84-
));
85-
env.AddGlobal("strftime_now", jinja2::UserCallable(
86-
/*callable*/ [](auto &params) -> jinja2::Value {
69+
environment = minja::Context::builtins();
70+
environment->set("strftime_now", minja::simple_function(
71+
"strftime_now", { "format" },
72+
[](const std::shared_ptr<minja::Context> &, minja::Value &args) -> minja::Value {
73+
auto format = args.at("format").get<std::string>();
8774
using Clock = std::chrono::system_clock;
88-
auto formatArg = params.args.find("format");
89-
if (formatArg == params.args.end() || !formatArg->second.isString())
90-
throw std::runtime_error("'format' argument to strftime_now() must be a string");
9175
time_t nowUnix = Clock::to_time_t(Clock::now());
9276
auto localDate = *std::localtime(&nowUnix);
9377
std::ostringstream ss;
94-
ss << std::put_time(&localDate, formatArg->second.asString().c_str());
78+
ss << std::put_time(&localDate, format.c_str());
9579
return ss.str();
96-
},
97-
/*argsInfo*/ { jinja2::ArgInfo("format", /*isMandatory*/ true) }
80+
}
9881
));
9982
}
100-
return &*environment;
83+
return environment;
10184
}
10285

10386
class LLModelStore {
@@ -757,19 +740,18 @@ static uint parseJinjaTemplateVersion(QStringView tmpl)
757740
return 0;
758741
}
759742

760-
static auto loadJinjaTemplate(
761-
std::optional<jinja2::Template> &tmpl /*out*/, const std::string &source
762-
) -> jinja2::Result<void>
743+
static std::shared_ptr<minja::TemplateNode> loadJinjaTemplate(const std::string &source)
763744
{
764-
tmpl.emplace(jinjaEnv());
765-
return tmpl->Load(source);
745+
return minja::Parser::parse(source, { .trim_blocks = true, .lstrip_blocks = true, .keep_trailing_newline = false });
766746
}
767747

768748
std::optional<std::string> ChatLLM::checkJinjaTemplateError(const std::string &source)
769749
{
770-
std::optional<jinja2::Template> tmpl;
771-
if (auto res = loadJinjaTemplate(tmpl, source); !res)
772-
return res.error().ToString();
750+
try {
751+
loadJinjaTemplate(source);
752+
} catch (const std::runtime_error &e) {
753+
return e.what();
754+
}
773755
return std::nullopt;
774756
}
775757

@@ -801,13 +783,13 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
801783
uint version = parseJinjaTemplateVersion(chatTemplate);
802784

803785
auto makeMap = [version](const MessageItem &item) {
804-
return jinja2::GenericMap([msg = std::make_shared<JinjaMessage>(version, item)] { return msg.get(); });
786+
return JinjaMessage(version, item).AsJson();
805787
};
806788

807789
std::unique_ptr<MessageItem> systemItem;
808790
bool useSystem = !isAllSpace(systemMessage);
809791

810-
jinja2::ValuesList messages;
792+
json::array_t messages;
811793
messages.reserve(useSystem + items.size());
812794
if (useSystem) {
813795
systemItem = std::make_unique<MessageItem>(MessageItem::Type::System, systemMessage.toUtf8());
@@ -816,27 +798,29 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
816798
for (auto &item : items)
817799
messages.emplace_back(makeMap(item));
818800

819-
jinja2::ValuesList toolList;
801+
json::array_t toolList;
820802
const int toolCount = ToolModel::globalInstance()->count();
821803
for (int i = 0; i < toolCount; ++i) {
822804
Tool *t = ToolModel::globalInstance()->get(i);
823805
toolList.push_back(t->jinjaValue());
824806
}
825807

826-
jinja2::ValuesMap params {
808+
json::object_t params {
827809
{ "messages", std::move(messages) },
828810
{ "add_generation_prompt", true },
829811
{ "toolList", toolList },
830812
};
831813
for (auto &[name, token] : model->specialTokens())
832814
params.emplace(std::move(name), std::move(token));
833815

834-
std::optional<jinja2::Template> tmpl;
835-
auto maybeRendered = loadJinjaTemplate(tmpl, chatTemplate.toStdString())
836-
.and_then([&] { return tmpl->RenderAsString(params); });
837-
if (!maybeRendered)
838-
throw std::runtime_error(fmt::format("Failed to parse chat template: {}", maybeRendered.error().ToString()));
839-
return *maybeRendered;
816+
try {
817+
auto tmpl = loadJinjaTemplate(chatTemplate.toStdString());
818+
auto context = minja::Context::make(minja::Value(std::move(params)), jinjaEnv());
819+
return tmpl->render(context);
820+
} catch (const std::runtime_error &e) {
821+
throw std::runtime_error(fmt::format("Failed to parse chat template: {}", e.what()));
822+
}
823+
Q_UNREACHABLE();
840824
}
841825

842826
auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,

gpt4all-chat/src/jinja_helpers.cpp

Lines changed: 59 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,111 +7,75 @@
77
#include <QString>
88
#include <QUrl>
99

10-
#include <memory>
10+
#include <iterator>
11+
#include <map>
12+
#include <ranges>
1113
#include <vector>
1214

13-
using namespace std::literals::string_view_literals;
15+
namespace views = std::views;
16+
using json = nlohmann::ordered_json;
1417

1518

16-
JinjaResultInfo::~JinjaResultInfo() = default;
17-
18-
const JinjaFieldMap<ResultInfo> JinjaResultInfo::s_fields = {
19-
{ "collection", [](auto &s) { return s.collection.toStdString(); } },
20-
{ "path", [](auto &s) { return s.path .toStdString(); } },
21-
{ "file", [](auto &s) { return s.file .toStdString(); } },
22-
{ "title", [](auto &s) { return s.title .toStdString(); } },
23-
{ "author", [](auto &s) { return s.author .toStdString(); } },
24-
{ "date", [](auto &s) { return s.date .toStdString(); } },
25-
{ "text", [](auto &s) { return s.text .toStdString(); } },
26-
{ "page", [](auto &s) { return s.page; } },
27-
{ "file_uri", [](auto &s) { return s.fileUri() .toStdString(); } },
28-
};
29-
30-
JinjaPromptAttachment::~JinjaPromptAttachment() = default;
31-
32-
const JinjaFieldMap<PromptAttachment> JinjaPromptAttachment::s_fields = {
33-
{ "url", [](auto &s) { return s.url.toString() .toStdString(); } },
34-
{ "file", [](auto &s) { return s.file() .toStdString(); } },
35-
{ "processed_content", [](auto &s) { return s.processedContent().toStdString(); } },
36-
};
37-
38-
std::vector<std::string> JinjaMessage::GetKeys() const
19+
json::object_t JinjaResultInfo::AsJson() const
3920
{
40-
std::vector<std::string> result;
41-
auto &keys = this->keys();
42-
result.reserve(keys.size());
43-
result.assign(keys.begin(), keys.end());
44-
return result;
21+
return {
22+
{ "collection", m_source->collection.toStdString() },
23+
{ "path", m_source->path .toStdString() },
24+
{ "file", m_source->file .toStdString() },
25+
{ "title", m_source->title .toStdString() },
26+
{ "author", m_source->author .toStdString() },
27+
{ "date", m_source->date .toStdString() },
28+
{ "text", m_source->text .toStdString() },
29+
{ "page", m_source->page },
30+
{ "file_uri", m_source->fileUri() .toStdString() },
31+
};
4532
}
4633

47-
auto JinjaMessage::keys() const -> const std::unordered_set<std::string_view> &
34+
json::object_t JinjaPromptAttachment::AsJson() const
4835
{
49-
static const std::unordered_set<std::string_view> baseKeys
50-
{ "role", "content" };
51-
static const std::unordered_set<std::string_view> userKeys
52-
{ "role", "content", "sources", "prompt_attachments" };
53-
switch (m_item->type()) {
54-
using enum MessageItem::Type;
55-
case System:
56-
case Response:
57-
case ToolResponse:
58-
return baseKeys;
59-
case Prompt:
60-
return userKeys;
61-
break;
62-
}
63-
Q_UNREACHABLE();
36+
return {
37+
{ "url", m_attachment->url.toString() .toStdString() },
38+
{ "file", m_attachment->file() .toStdString() },
39+
{ "processed_content", m_attachment->processedContent().toStdString() },
40+
};
6441
}
6542

66-
bool operator==(const JinjaMessage &a, const JinjaMessage &b)
43+
json::object_t JinjaMessage::AsJson() const
6744
{
68-
if (a.m_item == b.m_item)
69-
return true;
70-
const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item);
71-
auto type = ia.type();
72-
if (type != ib.type() || ia.content() != ib.content())
73-
return false;
74-
75-
switch (type) {
76-
using enum MessageItem::Type;
77-
case System:
78-
case Response:
79-
case ToolResponse:
80-
return true;
81-
case Prompt:
82-
return ia.sources() == ib.sources() && ia.promptAttachments() == ib.promptAttachments();
83-
break;
84-
}
85-
Q_UNREACHABLE();
86-
}
87-
88-
const JinjaFieldMap<JinjaMessage> JinjaMessage::s_fields = {
89-
{ "role", [](auto &m) {
90-
switch (m.item().type()) {
45+
json::object_t obj;
46+
{
47+
json::string_t role;
48+
switch (m_item->type()) {
9149
using enum MessageItem::Type;
92-
case System: return "system"sv;
93-
case Prompt: return "user"sv;
94-
case Response: return "assistant"sv;
95-
case ToolResponse: return "tool"sv;
96-
break;
50+
case System: role = "system"; break;
51+
case Prompt: role = "user"; break;
52+
case Response: role = "assistant"; break;
53+
case ToolResponse: role = "tool"; break;
54+
}
55+
obj.emplace_back("role", std::move(role));
56+
}
57+
{
58+
QString content;
59+
if (m_version == 0 && m_item->type() == MessageItem::Type::Prompt) {
60+
content = m_item->bakedPrompt();
61+
} else {
62+
content = m_item->content();
63+
}
64+
obj.emplace_back("content", content.toStdString());
65+
}
66+
if (m_item->type() == MessageItem::Type::Prompt) {
67+
{
68+
auto sources = m_item->sources() | views::transform([](auto &r) {
69+
return JinjaResultInfo(r).AsJson();
70+
});
71+
obj.emplace("sources", json::array_t(sources.begin(), sources.end()));
72+
}
73+
{
74+
auto attachments = m_item->promptAttachments() | views::transform([](auto &pa) {
75+
return JinjaPromptAttachment(pa).AsJson();
76+
});
77+
obj.emplace("prompt_attachments", json::array_t(attachments.begin(), attachments.end()));
9778
}
98-
Q_UNREACHABLE();
99-
} },
100-
{ "content", [](auto &m) {
101-
if (m.version() == 0 && m.item().type() == MessageItem::Type::Prompt)
102-
return m.item().bakedPrompt().toStdString();
103-
return m.item().content().toStdString();
104-
} },
105-
{ "sources", [](auto &m) {
106-
auto sources = m.item().sources() | views::transform([](auto &r) {
107-
return jinja2::GenericMap([map = std::make_shared<JinjaResultInfo>(r)] { return map.get(); });
108-
});
109-
return jinja2::ValuesList(sources.begin(), sources.end());
110-
} },
111-
{ "prompt_attachments", [](auto &m) {
112-
auto attachments = m.item().promptAttachments() | views::transform([](auto &pa) {
113-
return jinja2::GenericMap([map = std::make_shared<JinjaPromptAttachment>(pa)] { return map.get(); });
114-
});
115-
return jinja2::ValuesList(attachments.begin(), attachments.end());
116-
} },
117-
};
79+
}
80+
return obj;
81+
}

0 commit comments

Comments
 (0)