@@ -512,6 +512,7 @@ void chat(
512512 int prev_token;
513513 int pos = 0 ; // position in the sequence
514514 while (pos < steps) {
515+
515516 // when it is the user's turn to contribute tokens to the dialog...
516517 if (user_turn) {
517518 // get the (optional) system prompt at position 0
@@ -538,19 +539,21 @@ void chat(
538539 }
539540 // render user/system prompts into the Llama 2 Chat schema
540541 if (pos == 0 && system_prompt[0 ] != ' \0 ' ) {
541- const char system_template[] = " <s>[INST] <<SYS>>\n %s\n <</SYS>>\n\n %s [/INST]" ;
542+ // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
543+ const char system_template[] = " [INST] <<SYS>>\n %s\n <</SYS>>\n\n %s [/INST]" ;
542544 snprintf (
543545 rendered_prompt, RENDERED_PROMPT_SIZE-1 , system_template, system_prompt, user_prompt);
544546 } else {
545547 // Assistant should produce </s>, so we do not include it in template
546- // "</s><s>[INST] %s [/INST]" for subsequent user inputs.
547- const char user_template[] = " <s> [INST] %s [/INST]" ;
548+ // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
549+ const char user_template[] = " [INST] %s [/INST]" ;
548550 snprintf (rendered_prompt, RENDERED_PROMPT_SIZE-1 , user_template, user_prompt);
549551 }
550552
551553 // encode the rendered prompt into tokens
552554 prompt_tokens = tokenizer->encode (rendered_prompt, 1 , 0 );
553555 num_prompt_tokens = prompt_tokens.size ();
556+
554557 user_idx = 0 ; // reset the user index
555558 user_turn = 0 ;
556559 printf (" Assistant: " );
@@ -566,27 +569,27 @@ void chat(
566569 token = next;
567570 }
568571
572+ // forward the transformer to get logits for the next token
573+ float * logits = forward (transformer, token, pos);
574+ next = sample (sampler, logits);
575+
576+
569577 if (token == EOS_TOKEN) {
570578 user_turn = 1 ;
571- pos++;
572- } else {
573- // forward the transformer to get logits for the next token
574- float * logits = forward (transformer, token, pos);
575- next = sample (sampler, logits);
576- pos++;
577-
578- if (user_idx >= num_prompt_tokens && next != EOS_TOKEN && next != SOS_TOKEN) {
579- // the Assistant is responding, so print its output
580- std::string piece = tokenizer->decode (token, next);
581- safe_printf (piece.c_str ()); // same as printf("%s", piece), but skips
582- // "unsafe" bytes
583- fflush (stdout);
584- }
585- if (next == EOS_TOKEN) {
586- printf (" \n " );
587- }
588579 }
589580
581+ if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
582+ std::string piece = tokenizer->decode (token, next);
583+ safe_printf (piece.c_str ()); // same as printf("%s", piece), but skips
584+ // "unsafe" bytes
585+ fflush (stdout);
586+ }
587+
588+ if (next == EOS_TOKEN) {
589+ printf (" \n " );
590+ }
591+ pos++;
592+
590593 }
591594 printf (" \n " );
592595}
0 commit comments