@@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
234234
235235 std::vector<llama_token> embd_inp;
236236
237- if (params.interactive_first || params.instruct || !params.prompt .empty () || session_tokens.empty ()) {
237+ if (params.interactive_first || params.instruct || params. chatml || !params.prompt .empty () || session_tokens.empty ()) {
238238 LOG (" tokenize the prompt\n " );
239+ if (params.chatml ) {
240+ params.prompt = " <|im_start|>system\n " + params.prompt + " <|im_end|>" ;
241+ }
239242 embd_inp = ::llama_tokenize (ctx, params.prompt , add_bos, true );
240243 } else {
241244 LOG (" use session tokens\n " );
@@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
313316 }
314317
315318 // number of tokens to keep when resetting context
316- if (params.n_keep < 0 || params.n_keep > (int ) embd_inp.size () || params.instruct ) {
319+ if (params.n_keep < 0 || params.n_keep > (int ) embd_inp.size () || params.instruct || params. chatml ) {
317320 params.n_keep = (int )embd_inp.size ();
318321 }
319322
@@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
324327 LOG (" inp_pfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, inp_pfx).c_str ());
325328 LOG (" inp_sfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, inp_sfx).c_str ());
326329
330+ // chatml prefix & suffix
331+ const auto cml_pfx = ::llama_tokenize (ctx, " \n <|im_start|>user\n " , add_bos, true );
332+ const auto cml_sfx = ::llama_tokenize (ctx, " <|im_end|>\n <|im_start|>assistant\n " , false , true );
333+
334+ LOG (" cml_pfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, cml_pfx).c_str ());
335+ LOG (" cml_sfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, cml_sfx).c_str ());
336+
327337 // in instruct mode, we inject a prefix and a suffix to each input by the user
328338 if (params.instruct ) {
329339 params.interactive_first = true ;
330340 params.antiprompt .push_back (" ### Instruction:\n\n " );
331341 }
342+ // similar for chatml mode
343+ else if (params.chatml ) {
344+ params.interactive_first = true ;
345+ params.antiprompt .push_back (" <|im_start|>user\n " );
346+ }
332347
333348 // enable interactive mode if interactive start is specified
334349 if (params.interactive_first ) {
@@ -705,15 +720,15 @@ int main(int argc, char ** argv) {
705720
706721 is_interacting = true ;
707722 printf (" \n " );
708- } else if (params.instruct ) {
723+ } else if (params.instruct || params. chatml ) {
709724 is_interacting = true ;
710725 }
711726 }
712727
713728 if (n_past > 0 && is_interacting) {
714729 LOG (" waiting for user input\n " );
715730
716- if (params.instruct ) {
731+ if (params.instruct || params. chatml ) {
717732 printf (" \n > " );
718733 }
719734
@@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
760775 n_consumed = embd_inp.size ();
761776 embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
762777 }
778+ // chatml mode: insert user chat prefix
779+ if (params.chatml && !is_antiprompt) {
780+ LOG (" inserting chatml prefix\n " );
781+ n_consumed = embd_inp.size ();
782+ embd_inp.insert (embd_inp.end (), cml_pfx.begin (), cml_pfx.end ());
783+ }
763784 if (params.escape ) {
764785 process_escapes (buffer);
765786 }
@@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
778799 LOG (" inserting instruction suffix\n " );
779800 embd_inp.insert (embd_inp.end (), inp_sfx.begin (), inp_sfx.end ());
780801 }
802+ // chatml mode: insert assistant chat suffix
803+ if (params.chatml ) {
804+ LOG (" inserting chatml suffix\n " );
805+ embd_inp.insert (embd_inp.end (), cml_sfx.begin (), cml_sfx.end ());
806+ }
781807
782808 for (size_t i = original_size; i < embd_inp.size (); ++i) {
783809 const llama_token token = embd_inp[i];
@@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
803829 }
804830
805831 // end of text token
806- if (!embd.empty () && embd.back () == llama_token_eos (model) && !(params.instruct || params.interactive )) {
832+ if (!embd.empty () && embd.back () == llama_token_eos (model) && !(params.instruct || params.interactive || params. chatml )) {
807833 LOG_TEE (" [end of text]\n " );
808834 break ;
809835 }
0 commit comments