From e7de8dedf90390d6e7e69b5901b9a39001072e2e Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Thu, 23 Mar 2023 06:52:48 -0600
Subject: [PATCH 1/3] Add --bias-tokens and --dump-prompt-tokens flags and
 associated backend changes.

--bias-tokens allows passing a list of token IDs and floating point biases.
This can be used to emulate llama.cpp's --ignore-eos feature in a more
flexible way. For example, using "--bias-tokens 2=-1" will prevent the EOD
token from being generated.

--dump-prompt-tokens will just dump the tokens for the prompt after loading,
in two formats: first as a plain list of token IDs, and then as a list
pairing each token ID with the string it came from.

Some backend changes were necessary to make this possible, but the behavior
of the existing code should remain unchanged. One thing I did was move the
tokenize method to Vocabulary and change it to generate Vec<(&str, TokenId)>.
The tokenize method on the model remains as a wrapper around that which just
strips out the string part.
---
 llama-cli/src/cli_args.rs |  19 ++++
 llama-cli/src/main.rs     |  34 +++++++-
 llama-rs/src/lib.rs       | 178 ++++++++++++++++++++++++++------------
 3 files changed, 177 insertions(+), 54 deletions(-)

diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs
index d5b84e28..f71c8f27 100644
--- a/llama-cli/src/cli_args.rs
+++ b/llama-cli/src/cli_args.rs
@@ -1,4 +1,5 @@
 use clap::Parser;
+use llama_rs::TokenBias;
 use once_cell::sync::Lazy;
 
 #[derive(Parser, Debug)]
@@ -82,6 +83,24 @@ pub struct Args {
     /// from the cache.
     #[arg(long, default_value_t = false)]
     pub float16: bool,
+
+    /// A comma separated list of token biases. The list should be in the format
+    /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
+    /// floating point number.
+    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
+    /// (start of document) and 2 (end of document) to -1.0, which effectively
+    /// prevents the model from generating responses containing those token IDs.
+    #[arg(long, default_value = None, value_parser = parse_bias)]
+    pub bias_tokens: Option<TokenBias>,
+
+    /// Dumps the prompt to console and exits, first as a comma separated list of token IDs
+    /// and then as a list of comma separated string keys and token ID values.
+    #[arg(long, default_value_t = false)]
+    pub dump_prompt_tokens: bool,
+}
+
+fn parse_bias(s: &str) -> Result<TokenBias, String> {
+    s.parse()
 }
 
 /// CLI args are stored in a lazy static variable so they're accessible from
diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index d9808ff8..2ea92ace 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -3,7 +3,7 @@ use std::{convert::Infallible, io::Write};
 use cli_args::CLI_ARGS;
 use llama_rs::{
     InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
-    ModelKVMemoryType,
+    ModelKVMemoryType, Vocabulary,
 };
 use rand::thread_rng;
 use rand::SeedableRng;
@@ -64,6 +64,32 @@ fn repl_mode(
     }
 }
 
+fn dump_tokens(text: &str, vocab: &Vocabulary) -> Result<(), InferenceError> {
+    let toks = match vocab.tokenize(text, false) {
+        Ok(toks) => toks,
+        Err(e) => {
+            log::error!("Could not tokenize prompt: {e}");
+            return Err(e);
+        }
+    };
+    log::info!("=== Dumping prompt tokens:");
+    log::info!(
+        "{}",
+        toks.iter()
+            .map(|(_, tid)| tid.to_string())
+            .collect::<Vec<String>>()
+            .join(", ")
+    );
+    log::info!(
+        "{}",
+        toks.iter()
+            .map(|(s, tid)| format!("{s:?}:{tid}"))
+            .collect::<Vec<String>>()
+            .join(", ")
+    );
+    Ok(())
+}
+
 fn main() {
     env_logger::builder()
         .filter_level(log::LevelFilter::Info)
@@ -79,6 +105,7 @@ fn main() {
         top_p: args.top_p,
         repeat_penalty: args.repeat_penalty,
         temp: args.temp,
+        bias_tokens: args.bias_tokens.clone().unwrap_or_default(),
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
@@ -164,6 +191,11 @@ fn main() {
 
     log::info!("Model fully loaded!");
 
+    if args.dump_prompt_tokens {
+        dump_tokens(&prompt, &vocab).ok();
+        return;
+    }
+
     let mut rng = if let Some(seed) = CLI_ARGS.seed {
         rand::rngs::StdRng::seed_from_u64(seed)
     } else {
diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index fcf7e7a9..04e68a8c 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -5,6 +5,7 @@ use std::{
     fmt::Display,
     io::{BufRead, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
+    str::FromStr,
     time,
 };
 
@@ -131,6 +132,7 @@ pub struct InferenceParameters {
     pub top_p: f32,
     pub repeat_penalty: f32,
     pub temp: f32,
+    pub bias_tokens: TokenBias,
 }
 
 impl Default for InferenceParameters {
@@ -142,6 +144,7 @@ impl Default for InferenceParameters {
             top_p: 0.95,
             repeat_penalty: 1.30,
             temp: 0.80,
+            bias_tokens: TokenBias::default(),
         }
     }
 }
@@ -246,6 +249,55 @@ impl Display for OutputToken<'_> {
     }
 }
 
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct TokenBias(Vec<(TokenId, f32)>);
+
+impl TokenBias {
+    pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
+        v.sort_by_cached_key(|(tid, _)| *tid);
+        v.dedup_by_key(|(tid, _)| *tid);
+        Self(v)
+    }
+
+    pub fn get(&self, tid: TokenId) -> f32 {
+        self.0
+            .binary_search_by_key(&tid, |(tid, _)| *tid)
+            .map_or(0.0, |idx| self.0[idx].1)
+    }
+}
+
+impl FromStr for TokenBias {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let x = s
+            .split(',')
+            .map(|kv| {
+                let (k, v) = kv
+                    .trim()
+                    .split_once('=')
+                    .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
+                let tid: TokenId = k
+                    .trim()
+                    .parse()
+                    .map_err(|e: std::num::ParseIntError| e.to_string())?;
+                let bias: f32 = v
+                    .trim()
+                    .parse()
+                    .map_err(|e: std::num::ParseFloatError| e.to_string())?;
+                Result::<_, String>::Ok((tid, bias))
+            })
+            .collect::<Result<_, _>>()?;
+        Ok(TokenBias::new(x))
+    }
+}
+
+impl std::fmt::Display for TokenBias {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self.0)
+    }
+}
+
 /// Each variant represents a step within the process of loading the model.
 /// These can be used to report progress to the user.
 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
@@ -895,18 +947,24 @@ impl Model {
         {
             let scale = 1.0 / params.temp;
             for (i, &logit) in logits.iter().enumerate() {
+                let tid = i as TokenId;
+                let bias = params.bias_tokens.get(tid);
+
                 // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
                 // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-                if session.last_n_tokens.contains(&(i as TokenId)) {
+                let val = if bias != 0.0 {
+                    bias
+                } else if session.last_n_tokens.contains(&tid) {
                     // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                     if logits[i] < 0.0 {
-                        logits_id.push((logit * scale * params.repeat_penalty, i as TokenId));
+                        logit * scale * params.repeat_penalty
                     } else {
-                        logits_id.push((logit * scale / params.repeat_penalty, i as TokenId));
+                        logit * scale / params.repeat_penalty
                     }
                 } else {
-                    logits_id.push((logit * scale, i as TokenId));
-                }
+                    logit * scale
+                };
+                logits_id.push((val, tid));
             }
         }
 
@@ -1201,60 +1259,17 @@ impl Model {
         session.n_past += input_tokens.len();
     }
 
-    // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
     pub fn tokenize(
         &self,
         vocab: &Vocabulary,
         text: &str,
         bos: bool,
     ) -> Result<Vec<TokenId>, InferenceError> {
-        let len = text.len();
-
-        let mut score = vec![0usize; len + 1];
-        let mut prev = vec![TokenId::default(); len + 1];
-
-        for i in 0..len {
-            let max_len = (len - i).min(vocab.max_token_length);
-            for sub_len in 1..=max_len {
-                let sub = &text.as_bytes()[i..i + sub_len];
-                let Ok(sub) = std::str::from_utf8(sub) else { continue; };
-                let token = vocab.token_to_id.get(sub);
-
-                if let Some(token) = token {
-                    let token_score = sub.len() * sub.len();
-                    let local_score = score[i] + token_score;
-                    let next = i + sub_len;
-
-                    if score[next] < local_score {
-                        score[next] = local_score;
-                        prev[next] = *token;
-                    }
-                }
-            }
-        }
-
-        // Backward pass
-        let mut res = vec![];
-        let mut i = len;
-        while i > 0 {
-            let token_id = prev[i];
-            if token_id == 0 {
-                return Err(InferenceError::TokenizationFailed);
-            }
-            res.push(token_id);
-            let token = &vocab.id_to_token[token_id as usize];
-            i -= token.len();
-        }
-
-        if bos {
-            // TODO: replace with vocab.bos
-            res.push(1);
-        }
-
-        // Pieces are in reverse order so correct that
-        res.reverse();
-
-        Ok(res)
+        Ok(vocab
+            .tokenize(text, bos)?
+            .iter()
+            .map(|(_, tid)| *tid)
+            .collect::<Vec<TokenId>>())
     }
 
     /// Sets the state of the model, from a previously obtained InferenceSnapshot
@@ -1460,3 +1475,60 @@ impl InferenceSnapshot {
         Self::read(&mut reader)
     }
 }
+
+impl Vocabulary {
+    // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
+    pub fn tokenize<'a>(
+        &'a self,
+        text: &str,
+        bos: bool,
+    ) -> Result<Vec<(&'a str, TokenId)>, InferenceError> {
+        let len = text.len();
+
+        let mut score = vec![0usize; len + 1];
+        let mut prev = vec![TokenId::default(); len + 1];
+
+        for i in 0..len {
+            let max_len = (len - i).min(self.max_token_length);
+            for sub_len in 1..=max_len {
+                let sub = &text.as_bytes()[i..i + sub_len];
+                let Ok(sub) = std::str::from_utf8(sub) else { continue; };
+                let token = self.token_to_id.get(sub);
+
+                if let Some(token) = token {
+                    let token_score = sub.len() * sub.len();
+                    let local_score = score[i] + token_score;
+                    let next = i + sub_len;
+
+                    if score[next] < local_score {
+                        score[next] = local_score;
+                        prev[next] = *token;
+                    }
+                }
+            }
+        }
+
+        // Backward pass
+        let mut res = vec![];
+        let mut i = len;
+        while i > 0 {
+            let token_id = prev[i];
+            if token_id == 0 {
+                return Err(InferenceError::TokenizationFailed);
+            }
+            let token = self.id_to_token[token_id as usize].as_str();
+            res.push((token, token_id));
+            i -= token.len();
+        }
+
+        if bos {
+            // TODO: replace with vocab.bos
+            res.push(("", 1));
+        }
+
+        // Pieces are in reverse order so correct that
+        res.reverse();
+
+        Ok(res)
+    }
+}

From 8e9415936935d0a5df2317d60c752f9e953e50f0 Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Thu, 23 Mar 2023 12:52:18 -0600
Subject: [PATCH 2/3] Change TokenBias::get to return Option<f32>.

Change name of bias binding to logit_override in sample_top_p_k.
---
 llama-rs/src/lib.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index 04e68a8c..53f89c36 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -259,10 +259,11 @@ impl TokenBias {
         Self(v)
     }
 
-    pub fn get(&self, tid: TokenId) -> f32 {
+    pub fn get(&self, tid: TokenId) -> Option<f32> {
         self.0
             .binary_search_by_key(&tid, |(tid, _)| *tid)
-            .map_or(0.0, |idx| self.0[idx].1)
+            .map(|idx| self.0[idx].1)
+            .ok()
     }
 }
 
@@ -948,12 +949,11 @@ impl Model {
             let scale = 1.0 / params.temp;
             for (i, &logit) in logits.iter().enumerate() {
                 let tid = i as TokenId;
-                let bias = params.bias_tokens.get(tid);
 
                 // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
                 // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-                let val = if bias != 0.0 {
-                    bias
+                let val = if let Some(logit_override) = params.bias_tokens.get(tid) {
+                    logit_override
                 } else if session.last_n_tokens.contains(&tid) {
                     // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                     if logits[i] < 0.0 {

From 67950ecba87b437820eb0931b1248eb0262d2e7b Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Thu, 23 Mar 2023 14:03:52 -0600
Subject: [PATCH 3/3] Rename flag to --token-bias, add --ignore-eos alias.
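
The --ignore-eos flag is just a convenience alias: when no explicit
--token-bias is given, it behaves like passing --token-bias "2=-1.0",
where 2 is the hardcoded EOD token ID now exported from llama-rs as
EOD_TOKEN_ID. A rough sketch of the intended TokenBias behavior (an
illustration only, assuming the API added in the previous patches plus
the new constant below):

    use llama_rs::{TokenBias, EOD_TOKEN_ID};

    let bias: TokenBias = "2=-1.0".parse().expect("valid bias list");
    // The EOD token's logit is overridden with -1.0, so it is effectively
    // never sampled.
    assert_eq!(bias.get(EOD_TOKEN_ID), Some(-1.0));
    // Tokens without an entry fall through to the normal sampling path.
    assert_eq!(bias.get(5), None);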
---
 llama-cli/src/cli_args.rs |  8 +++++++-
 llama-cli/src/main.rs     | 10 ++++++++--
 llama-rs/src/lib.rs       | 10 ++++++----
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs
index f71c8f27..e6b6dc61 100644
--- a/llama-cli/src/cli_args.rs
+++ b/llama-cli/src/cli_args.rs
@@ -91,7 +91,13 @@ pub struct Args {
     /// (start of document) and 2 (end of document) to -1.0, which effectively
     /// prevents the model from generating responses containing those token IDs.
     #[arg(long, default_value = None, value_parser = parse_bias)]
-    pub bias_tokens: Option<TokenBias>,
+    pub token_bias: Option<TokenBias>,
+
+    /// Prevent the end of stream (EOS/EOD) token from being generated. This will allow the
+    /// model to generate text until it runs out of context space. Note: The --token-bias
+    /// option will override this if specified.
+    #[arg(long, default_value_t = false)]
+    pub ignore_eos: bool,
 
     /// Dumps the prompt to console and exits, first as a comma separated list of token IDs
     /// and then as a list of comma separated string keys and token ID values.
diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index 2ea92ace..b4b20488 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -3,7 +3,7 @@ use std::{convert::Infallible, io::Write};
 use cli_args::CLI_ARGS;
 use llama_rs::{
     InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
-    ModelKVMemoryType, Vocabulary,
+    ModelKVMemoryType, TokenBias, Vocabulary, EOD_TOKEN_ID,
 };
 use rand::thread_rng;
 use rand::SeedableRng;
@@ -105,7 +105,13 @@ fn main() {
         top_p: args.top_p,
         repeat_penalty: args.repeat_penalty,
         temp: args.temp,
-        bias_tokens: args.bias_tokens.clone().unwrap_or_default(),
+        bias_tokens: args.token_bias.clone().unwrap_or_else(|| {
+            if args.ignore_eos {
+                TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)])
+            } else {
+                TokenBias::default()
+            }
+        }),
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index 53f89c36..f714b343 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -14,6 +14,8 @@ use thiserror::Error;
 use partial_sort::PartialSort;
 use rand::{distributions::WeightedIndex, prelude::Distribution};
 
+pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)
+
 #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 pub struct Hyperparameters {
     n_vocab: i32,
@@ -1361,11 +1363,11 @@ impl InferenceSession {
         model.evaluate(self, params.n_threads, &[next_token]);
 
         // Return the next token
-        if next_token == 2 {
-            Ok(OutputToken::EndOfText)
+        Ok(if next_token as TokenId == EOD_TOKEN_ID {
+            OutputToken::EndOfText
         } else {
-            Ok(OutputToken::Token(&vocab.id_to_token[next_token as usize]))
-        }
+            OutputToken::Token(&vocab.id_to_token[next_token as usize])
+        })
     }
 
     // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe?