
Update main's interactive mode to use the chat handshake template support already available in llama.cpp (and currently only used by the server, ...) #6795


Closed · wants to merge 7 commits
1 change: 1 addition & 0 deletions common/CMakeLists.txt
@@ -65,6 +65,7 @@ add_library(${TARGET} STATIC
train.cpp
ngram-cache.h
ngram-cache.cpp
chaton.hpp
)

if (BUILD_SHARED_LIBS)
69 changes: 69 additions & 0 deletions common/chaton.hpp
@@ -0,0 +1,69 @@
#pragma once

/**
*
* Provides simple (and intentionally dumb) helpers for chatting with LLM chat/instruct models
* using the chat template they expect.
*
* Normally used to tag the system prompt and user messages.
* Currently used by the examples/main program.
*
* This builds on llama_chat_apply_template. When adding support for new chat templates,
* remember to update llama_chat_apply_template_internal as well as llama_chat_reverse_prompt.
*
* The examples/main program uses this when --chaton TEMPLATE_ID is passed to it along with -i.
* Sample TEMPLATE_IDs include chatml, llama2, llama3, ...
*
*/

#include <vector>
#include <string>

#include "llama.h"
#include "log.h"

// Tag the passed message suitably, as expected by the specified chat handshake template
// and role. If the specified template is not supported, this returns false.
inline bool llama_chat_apply_template_simple(
const std::string &tmpl,
const std::string &role,
const std::string &content,
std::string &dst,
bool add_ass) {
llama_chat_message msg = { role.c_str(), content.c_str() };
std::vector<char> buf(content.size() * 2); // This may under-allocate for small messages and over-allocate for large ones

int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
if (slen == -1) {
LOG_TEELN("WARN:%s:Unknown template [%s] requested", __func__, tmpl.c_str());
dst = "";
return false;
}
if ((size_t) slen > buf.size()) {
LOGLN("INFO:%s:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
buf.resize(slen);
slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
}

const std::string tagged_msg(buf.data(), slen);
LOGLN("INFO:%s:%s:%s", __func__, role.c_str(), tagged_msg.c_str());
dst = tagged_msg;
return true;
}

// Return what should be the reverse prompt for the given template id,
// i.e. the possible end-of-text tag(s) of the specified model type's chat response.
// Note that it appends these reverse prompts to any that may already exist in the passed vector.
inline bool llama_chat_reverse_prompt(std::string &template_id, std::vector<std::string> &rprompts) {
if (template_id == "chatml") {
rprompts.push_back("<|im_start|>user\n");
} else if (template_id == "llama2") {
rprompts.push_back("</s>");
} else if (template_id == "llama3") {
rprompts.push_back("<|eot_id|>");
} else {
LOG_TEELN("WARN:%s:Unknown template [%s] requested", __func__, template_id.c_str());
return false;
}
return true;
}
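
For context, here is a minimal sketch (not part of this change) of how the two helpers above are meant to be composed; it mirrors what the examples/main changes below do, and the template id and message strings are purely illustrative:

#include <string>
#include <vector>

#include "chaton.hpp"

// Illustrative only: tag a system prompt and one user turn for a given template id,
// and collect the reverse prompt(s) used to detect the end of the model's chat response.
static bool chaton_sketch() {
    std::string template_id = "chatml";

    std::string tagged_system;
    if (!llama_chat_apply_template_simple(template_id, "system", "You are a helpful assistant", tagged_system, false)) {
        return false;
    }

    std::string tagged_user;
    if (!llama_chat_apply_template_simple(template_id, "user", "Hello", tagged_user, true)) {
        return false;
    }

    std::vector<std::string> antiprompts;
    if (!llama_chat_reverse_prompt(template_id, antiprompts)) {
        return false;
    }

    // tagged_system and tagged_user would then be tokenized with llama_tokenize and fed to the model,
    // while antiprompts are watched for so control can be handed back to the user.
    return true;
}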
11 changes: 11 additions & 0 deletions common/common.cpp
@@ -868,6 +868,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.chatml = true;
return true;
}
if (arg == "--chaton") {
params.chaton = true;
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chaton_template_id = argv[i];
return true;
}
if (arg == "--infill") {
params.infill = true;
return true;
@@ -1378,6 +1387,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --version show version and build info\n");
printf(" -i, --interactive run in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" --chaton TEMPLATE_ID allow the interactive mode to apply the specified chat template before sending user input to model (you need to specify -i also)\n");
printf(" TEMPLATE_ID could be chatml, llama3, ...\n");
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
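
With this option in place, a typical invocation (model path and template id purely illustrative) would look something like:

./main -m models/my-model.gguf -p "You are a helpful assistant." -i --chaton llama3

where -m and -p are the usual model and prompt flags; in chaton mode the -p prompt is tagged as the system message and each interactive user turn is tagged with the same template before being sent to the model.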
2 changes: 2 additions & 0 deletions common/common.h
@@ -139,6 +139,8 @@ struct gpt_params {
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
bool chaton = false; // chaton mode (used to chat with models which have been trained for chat and/or instruct operation)
std::string chaton_template_id = ""; // the internal chat template to use
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it

54 changes: 41 additions & 13 deletions examples/main/main.cpp
@@ -1,4 +1,5 @@
#include "common.h"
#include "chaton.hpp"

#include "console.h"
#include "llama.h"
@@ -251,11 +252,17 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd_inp;

if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
if (params.interactive_first || params.instruct || params.chatml || params.chaton || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt: %s\n", params.prompt.c_str());
if (params.chatml) {
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
}
if (params.chaton) {
if (!llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, params.prompt, false)) {
LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "system", params.prompt.c_str());
exit(2);
}
}
embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
} else {
LOG("use session tokens\n");
@@ -333,7 +340,7 @@ int main(int argc, char ** argv) {
}

// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml || params.chaton) {
params.n_keep = (int)embd_inp.size();
} else {
params.n_keep += add_bos; // always keep the BOS token
@@ -363,6 +370,14 @@ int main(int argc, char ** argv) {
params.interactive_first = true;
params.antiprompt.emplace_back("<|im_start|>user\n");
}
// handle chaton mode; it adds to any reverse prompts specified explicitly by the user
if (params.chaton) {
params.interactive_first = true;
if (!llama_chat_reverse_prompt(params.chaton_template_id, params.antiprompt)) {
LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatTemplateType:%s", __func__, params.chaton_template_id.c_str());
exit(1);
}
}

// enable interactive mode if interactive start is specified
if (params.interactive_first) {
@@ -817,7 +832,7 @@ int main(int argc, char ** argv) {
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");

if (params.instruct || params.chatml) {
if (params.instruct || params.chatml || params.chaton) {
printf("\n> ");
}

@@ -876,15 +891,27 @@ int main(int argc, char ** argv) {
process_escapes(buffer);
}

const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);

LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
std::vector<llama_token> line_inp;
if (params.chaton) {
std::string f_chat;
if (!llama_chat_apply_template_simple(params.chaton_template_id, "user", buffer.c_str(), f_chat, true)) {
LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "user", params.prompt.c_str());
exit(2);
}
line_inp = ::llama_tokenize(ctx, f_chat, false, true);
LOG("formatted input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
} else {
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);

LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
}

// instruct mode: insert response suffix
if (params.instruct) {
Expand Down Expand Up @@ -921,6 +948,7 @@ int main(int argc, char ** argv) {
}

// end of text token
// chaton is expected to be used along with the interactive argument, so not checking for chaton separately
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
LOG_TEE(" [end of text]\n");
break;
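
To make the interactive flow above concrete, assuming the chatml handling in llama_chat_apply_template_internal, the tagging would work roughly as follows (illustrative strings, not captured output):

// the system prompt "You are a helpful assistant" is tagged as:
//   <|im_start|>system\nYou are a helpful assistant<|im_end|>\n
// an interactive user turn "Hello" is tagged as (add_ass = true):
//   <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n
// generation is handed back to the user once the registered reverse prompt
// "<|im_start|>user\n" appears in the model's output.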