Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/chat-auto-parser-helpers.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "chat-auto-parser.h"
#include "peg-parser.h"

#include <functional>
#include <optional>
#include <string>
Expand Down
1 change: 1 addition & 0 deletions common/chat-auto-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common.h"
#include "jinja/caps.h"
#include "peg-parser.h"
#include "nlohmann/json.hpp"

#include <chrono>
#include <optional>
Expand Down
36 changes: 22 additions & 14 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "jinja/caps.h"
#include "peg-parser.h"

#include "nlohmann/json.hpp"

#include <cstdio>
#include <cstdlib>
#include <ctime>
Expand Down Expand Up @@ -760,12 +762,12 @@ static void foreach_parameter(const json &
}
}

std::string common_chat_template_direct_apply(
static std::string common_chat_template_direct_apply_impl(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override,
const std::optional<json> & tools_override,
const std::optional<json> & additional_context) {
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt) {
jinja::context ctx(tmpl.source());

nlohmann::ordered_json inp = nlohmann::ordered_json{
Expand Down Expand Up @@ -812,6 +814,12 @@ std::string common_chat_template_direct_apply(
return result;
}

// Public entry point: applies `tmpl` to `inputs` without any overrides.
// Forwards to the file-local implementation with every optional override left empty.
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
    return common_chat_template_direct_apply_impl(
        tmpl,
        inputs,
        /* messages_override  = */ std::nullopt,
        /* tools_override     = */ std::nullopt,
        /* additional_context = */ std::nullopt);
}

static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
Expand Down Expand Up @@ -862,7 +870,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
"[THINK]",
Expand Down Expand Up @@ -945,7 +953,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
adjusted_messages.push_back(msg);
}

auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);

// Check if we need to replace the return token with end token during
// inference and without generation prompt. For more details see:
Expand Down Expand Up @@ -1067,7 +1075,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const autoparser::generation_params & inputs) {
common_chat_params data;

data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
">>>all",
Expand Down Expand Up @@ -1161,7 +1169,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::generation_params & inputs) {
common_chat_params data;

data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
Expand Down Expand Up @@ -1284,7 +1292,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
const autoparser::generation_params & inputs) {
common_chat_params data;

data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
Expand Down Expand Up @@ -1363,7 +1371,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
const autoparser::generation_params & inputs) {
common_chat_params data;

data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
Expand Down Expand Up @@ -1434,7 +1442,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(

common_chat_params data;

data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = false;
data.preserved_tokens = {
Expand Down Expand Up @@ -1669,9 +1677,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
}

params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;

Expand Down Expand Up @@ -1705,7 +1713,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
Expand Down
55 changes: 9 additions & 46 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#pragma once

#include "common.h"
#include "jinja/parser.h"
#include "nlohmann/json_fwd.hpp"
#include "peg-parser.h"
#include "jinja/parser.h"
#include "jinja/runtime.h"
#include "jinja/caps.h"
#include "nlohmann/json.hpp"

#include "nlohmann/json_fwd.hpp"

#include <chrono>
#include <functional>
Expand All @@ -19,8 +19,6 @@
using chat_template_caps = jinja::caps;
using json = nlohmann::ordered_json;

#include <nlohmann/json_fwd.hpp>

struct common_chat_templates;

namespace autoparser {
Expand Down Expand Up @@ -75,41 +73,9 @@ struct common_chat_template {
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }

// TODO: this is ugly, refactor it somehow
// Returns a copy of `messages` (which must be a JSON array) with `system_prompt` applied.
//  - If caps.supports_system_role is set: a leading "system" message has its content
//    replaced by `system_prompt`; otherwise a new "system" message is inserted at the front.
//  - If it is not set: an empty list receives a single "user" message holding the prompt;
//    otherwise the prompt plus a blank line is prepended to the first message's content.
json add_system(const json & messages, const std::string & system_prompt) const {
    GGML_ASSERT(messages.is_array());
    json patched = messages;

    if (caps.supports_system_role) {
        if (patched.empty() || patched[0].at("role") != "system") {
            // no leading system message yet: add one
            patched.insert(patched.begin(), json{
                {"role", "system"},
                {"content", system_prompt}
            });
        } else if (patched[0].at("role") == "system") {
            // leading system message exists: overwrite its content
            patched[0]["content"] = system_prompt;
        }
        return patched;
    }

    // The template has no system role: fold the prompt into a regular message instead.
    if (patched.empty()) {
        patched.insert(patched.begin(), json{
            {"role", "user"},
            {"content", system_prompt}
        });
        return patched;
    }

    auto & head = patched[0];
    if (!head.contains("content")) {
        head["content"] = "";
    }
    // get<std::string>() requires the existing content to be a JSON string
    const std::string merged = system_prompt + "\n\n" + head["content"].get<std::string>();
    head["content"] = merged;
    return patched;
}

Comment on lines -78 to -108
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't seem to be used anywhere, so I removed it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Effectively refactored into oblivion

// Returns the stored `caps` by value.
chat_template_caps original_caps() const { return caps; }

};

struct common_chat_msg {
Expand Down Expand Up @@ -256,8 +222,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");

bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");

struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
Expand All @@ -274,9 +240,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
bool use_jinja,
const std::map<std::string, std::string> & chat_template_kwargs);

const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);

// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
Expand All @@ -302,7 +268,4 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem

std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt);
const autoparser::generation_params & inputs);
Loading