Skip to content

Commit

Permalink
ChatON: Test ChatParts in chat-template-apply
Browse files Browse the repository at this point in the history
  • Loading branch information
hanishkvc committed Apr 23, 2024
1 parent 2824815 commit 9626779
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions common/chaton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,10 @@ inline std::string chaton_tmpl_apply_single(const std::string &tmpl, const std::
// then 1st user message will have user-prefix only if systemuser-1st-user-has-prefix is true
// NOTE: This currently doesnt return about which parts of the tagged message contain tags and which parts the user message
inline std::string chaton_tmpl_apply(const std::string &tmpl, const std::vector<llama_chat_message> &msgs) {
ChatParts cp = {};
std::stringstream ss;
ss << conMeta[tmpl][K_GLOBAL][K_BEGIN];
cp.add_part(ChatParts::S, conMeta[tmpl][K_GLOBAL][K_BEGIN]);
int cntSystem = 0;
int cntUser = 0;
int cntOthers = 0;
Expand All @@ -199,33 +201,51 @@ inline std::string chaton_tmpl_apply(const std::string &tmpl, const std::vector<
std::string begin = "";
try {
begin = conMeta[tmpl][role][K_BEGIN];
cp.add_part(ChatParts::S, begin);
} catch (json::exception &err) {

}
auto prefix = conMeta[tmpl][role][K_PREFIX];
if (role == K_SYSTEM) {
cntSystem += 1;
ss << begin << prefix;
cp.add_part(ChatParts::S, begin);
cp.add_part(ChatParts::S, prefix);
} else if (role == K_USER) {
cntUser += 1;
if ((cntSystem == 1) && (cntUser == 1)) {
if (conMeta[tmpl][K_SYSTEMUSER_1ST_USER_HAS_BEGIN]) {
ss << begin;
cp.add_part(ChatParts::S, begin);
}
if (conMeta[tmpl][K_SYSTEMUSER_1ST_USER_HAS_PREFIX]) {
ss << prefix;
cp.add_part(ChatParts::S, prefix);
}
} else {
ss << begin << prefix;
cp.add_part(ChatParts::S, begin);
cp.add_part(ChatParts::S, prefix);
}
} else {
cntOthers += 1;
ss << begin << prefix;
cp.add_part(ChatParts::S, begin);
cp.add_part(ChatParts::S, prefix);
}
ss << content << conMeta[tmpl][role][K_SUFFIX];
cp.add_part(ChatParts::N, content);
cp.add_part(ChatParts::S, conMeta[tmpl][role][K_SUFFIX]);
}
ss << conMeta[tmpl][K_GLOBAL][K_END];
cp.add_part(ChatParts::S, conMeta[tmpl][K_GLOBAL][K_END]);
cp.dump();
std::string taggedMsgs = ss.str();
std::string cpStr = cp.str();
if (taggedMsgs != cpStr) {
LOG_TEELN("DBUG:%s:Mismatch between CP[%s] and SS[%s]", __func__, cpStr.c_str(), taggedMsgs.c_str());
exit(2);
}
LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), taggedMsgs.c_str());
LOGLN("DBUG:%s:%s:CntSys[%d]:CntUsr[%d]:CntOthers[%d]", __func__, tmpl.c_str(), cntSystem, cntUser, cntOthers);
return taggedMsgs;
Expand Down

0 comments on commit 9626779

Please sign in to comment.