Merge pull request rustformers#64 from KerfuffleV2/feat-token_bias
Add --token-bias and --dump-prompt-tokens flags and associated backend changes.
setzer22 authored Mar 23, 2023
2 parents 309ecb7 + 67950ec commit 4cd6cb1
Showing 3 changed files with 195 additions and 58 deletions.
25 changes: 25 additions & 0 deletions llama-cli/src/cli_args.rs
@@ -1,4 +1,5 @@
use clap::Parser;
use llama_rs::TokenBias;
use once_cell::sync::Lazy;

#[derive(Parser, Debug)]
@@ -82,6 +83,30 @@ pub struct Args {
/// from the cache.
#[arg(long, default_value_t = false)]
pub float16: bool,

/// A comma separated list of token biases. The list should be in the format
/// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
/// floating point number.
    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
    /// (start of document) and 2 (end of document) to -1.0, which effectively
    /// prevents the model from generating responses containing those token IDs.
#[arg(long, default_value = None, value_parser = parse_bias)]
pub token_bias: Option<TokenBias>,

/// Prevent the end of stream (EOS/EOD) token from being generated. This will allow the
/// model to generate text until it runs out of context space. Note: The --token-bias
/// option will override this if specified.
#[arg(long, default_value_t = false)]
pub ignore_eos: bool,

    /// Dumps the prompt to console and exits, first as a comma separated list of token IDs
    /// and then as a list of comma separated string keys and token ID values.
#[arg(long, default_value_t = false)]
pub dump_prompt_tokens: bool,
}

fn parse_bias(s: &str) -> Result<TokenBias, String> {
s.parse()
}

/// CLI args are stored in a lazy static variable so they're accessible from
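As a concrete illustration of the accepted --token-bias syntax, here is a minimal sketch (not part of the commit) showing how a bias string round-trips through parse_bias, which simply delegates to the FromStr implementation added to llama-rs/src/lib.rs below. The asserted values follow directly from that implementation.

// Illustrative only: exercises TokenBias parsing as wired up by parse_bias.
use llama_rs::TokenBias;

fn main() {
    // "1=-1.0,2=-1.0" biases token 1 (start of document) and token 2
    // (end of document) to -1.0.
    let bias: TokenBias = "1=-1.0,2=-1.0".parse().expect("valid bias list");
    assert_eq!(bias.get(2), Some(-1.0));
    // Tokens that are not listed have no override.
    assert_eq!(bias.get(3), None);
    // Malformed items are reported as an Err(String), e.g. a missing '='.
    assert!("1".parse::<TokenBias>().is_err());
}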
40 changes: 39 additions & 1 deletion llama-cli/src/main.rs
@@ -3,7 +3,7 @@ use std::{convert::Infallible, io::Write};
use cli_args::CLI_ARGS;
use llama_rs::{
InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
ModelKVMemoryType,
ModelKVMemoryType, TokenBias, Vocabulary, EOD_TOKEN_ID,
};
use rand::thread_rng;
use rand::SeedableRng;
@@ -64,6 +64,32 @@ fn repl_mode(
}
}

fn dump_tokens(text: &str, vocab: &Vocabulary) -> Result<(), InferenceError> {
let toks = match vocab.tokenize(text, false) {
Ok(toks) => toks,
Err(e) => {
log::error!("Could not tokenize prompt: {e}");
return Err(e);
}
};
log::info!("=== Dumping prompt tokens:");
log::info!(
"{}",
toks.iter()
.map(|(_, tid)| tid.to_string())
.collect::<Vec<_>>()
.join(", ")
);
log::info!(
"{}",
toks.iter()
.map(|(s, tid)| format!("{s:?}:{tid}"))
.collect::<Vec<_>>()
.join(", ")
);
Ok(())
}
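
// Illustration (not part of the commit): for a hypothetical vocabulary in which
// "Hello" is token 15043 and " world" is token 3186, dump_tokens would log two
// lines after the "=== Dumping prompt tokens:" marker, roughly as follows
// (log prefixes omitted; real IDs depend on the loaded model's vocabulary):
//   15043, 3186
//   "Hello":15043, " world":3186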

fn main() {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
@@ -79,6 +105,13 @@ fn main() {
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
temp: args.temp,
bias_tokens: args.token_bias.clone().unwrap_or_else(|| {
if args.ignore_eos {
TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)])
} else {
TokenBias::default()
}
}),
};
let inference_session_params = {
let mem_typ = if args.float16 {
@@ -164,6 +197,11 @@

log::info!("Model fully loaded!");

if args.dump_prompt_tokens {
dump_tokens(&prompt, &vocab).ok();
return;
}

let mut rng = if let Some(seed) = CLI_ARGS.seed {
rand::rngs::StdRng::seed_from_u64(seed)
} else {
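The precedence between the two new flags, inlined in main.rs above, can be restated as a hedged sketch (the resolve_bias helper below is hypothetical, introduced only to spell out the rule): an explicit --token-bias value always wins, otherwise --ignore-eos biases the end-of-document token to -1.0, and otherwise no bias is applied.

// Illustrative helper only; main.rs inlines this logic when building
// InferenceParameters.
use llama_rs::{TokenBias, EOD_TOKEN_ID};

fn resolve_bias(token_bias: Option<TokenBias>, ignore_eos: bool) -> TokenBias {
    token_bias.unwrap_or_else(|| {
        if ignore_eos {
            // Same value main.rs uses to suppress the EOS/EOD token.
            TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)])
        } else {
            TokenBias::default()
        }
    })
}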
188 changes: 131 additions & 57 deletions llama-rs/src/lib.rs
@@ -5,6 +5,7 @@ use std::{
fmt::Display,
io::{BufRead, Read, Seek, SeekFrom},
path::{Path, PathBuf},
str::FromStr,
time,
};

@@ -13,6 +14,8 @@ use thiserror::Error;
use partial_sort::PartialSort;
use rand::{distributions::WeightedIndex, prelude::Distribution};

pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)

#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Hyperparameters {
n_vocab: i32,
@@ -131,6 +134,7 @@ pub struct InferenceParameters {
pub top_p: f32,
pub repeat_penalty: f32,
pub temp: f32,
pub bias_tokens: TokenBias,
}

impl Default for InferenceParameters {
@@ -142,6 +146,7 @@ impl Default for InferenceParameters {
top_p: 0.95,
repeat_penalty: 1.30,
temp: 0.80,
bias_tokens: TokenBias::default(),
}
}
}
@@ -246,6 +251,56 @@ impl Display for OutputToken<'_> {
}
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct TokenBias(Vec<(TokenId, f32)>);

impl TokenBias {
pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
v.sort_by_cached_key(|(tid, _)| *tid);
v.dedup_by_key(|(tid, _)| *tid);
Self(v)
}

pub fn get(&self, tid: TokenId) -> Option<f32> {
self.0
.binary_search_by_key(&tid, |(tid, _)| *tid)
.map(|idx| self.0[idx].1)
.ok()
}
}

impl FromStr for TokenBias {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let x = s
.split(',')
.map(|kv| {
let (k, v) = kv
.trim()
.split_once('=')
.ok_or_else(|| "Missing '=' in bias item".to_owned())?;
let tid: TokenId = k
.trim()
.parse()
.map_err(|e: std::num::ParseIntError| e.to_string())?;
let bias: f32 = v
.trim()
.parse()
.map_err(|e: std::num::ParseFloatError| e.to_string())?;
Result::<_, String>::Ok((tid, bias))
})
.collect::<Result<_, _>>()?;
Ok(TokenBias::new(x))
}
}

impl std::fmt::Display for TokenBias {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}

/// Each variant represents a step within the process of loading the model.
/// These can be used to report progress to the user.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
@@ -895,18 +950,23 @@ impl Model {
{
let scale = 1.0 / params.temp;
for (i, &logit) in logits.iter().enumerate() {
let tid = i as TokenId;

// repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if session.last_n_tokens.contains(&(i as TokenId)) {
let val = if let Some(logit_override) = params.bias_tokens.get(tid) {
logit_override
} else if session.last_n_tokens.contains(&tid) {
                    // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if logits[i] < 0.0 {
logits_id.push((logit * scale * params.repeat_penalty, i as TokenId));
logit * scale * params.repeat_penalty
} else {
logits_id.push((logit * scale / params.repeat_penalty, i as TokenId));
logit * scale / params.repeat_penalty
}
} else {
logits_id.push((logit * scale, i as TokenId));
}
logit * scale
};
logits_id.push((val, tid));
}
}
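
One consequence of the hunk above, worth spelling out: when --token-bias supplies an override for a token, that raw value replaces the logit and is pushed as-is, bypassing both the temperature scaling and the repetition penalty applied to every other token. A brief hedged sketch of the rule (names are illustrative, mirroring the variables in the hunk):

// Illustrative restatement of the per-token scoring above; not part of the commit.
fn scored(logit: f32, bias: Option<f32>, recently_seen: bool, scale: f32, repeat_penalty: f32) -> f32 {
    match bias {
        Some(overridden) => overridden, // --token-bias wins outright, unscaled
        None if recently_seen && logit < 0.0 => logit * scale * repeat_penalty,
        None if recently_seen => logit * scale / repeat_penalty,
        None => logit * scale,
    }
}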

@@ -1201,60 +1261,17 @@ impl Model {
session.n_past += input_tokens.len();
}

// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
pub fn tokenize(
&self,
vocab: &Vocabulary,
text: &str,
bos: bool,
) -> Result<Vec<TokenId>, InferenceError> {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(vocab.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let Ok(sub) = std::str::from_utf8(sub) else { continue; };
let token = vocab.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
}
}
}
}

// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(InferenceError::TokenizationFailed);
}
res.push(token_id);
let token = &vocab.id_to_token[token_id as usize];
i -= token.len();
}

if bos {
// TODO: replace with vocab.bos
res.push(1);
}

// Pieces are in reverse order so correct that
res.reverse();

Ok(res)
Ok(vocab
.tokenize(text, bos)?
.iter()
.map(|(_, tid)| *tid)
.collect::<Vec<TokenId>>())
}

/// Sets the state of the model, from a previously obtained InferenceSnapshot
@@ -1346,11 +1363,11 @@ impl InferenceSession {
model.evaluate(self, params.n_threads, &[next_token]);

// Return the next token
if next_token == 2 {
Ok(OutputToken::EndOfText)
Ok(if next_token as TokenId == EOD_TOKEN_ID {
OutputToken::EndOfText
} else {
Ok(OutputToken::Token(&vocab.id_to_token[next_token as usize]))
}
OutputToken::Token(&vocab.id_to_token[next_token as usize])
})
}

// todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe?
@@ -1460,3 +1477,60 @@ impl InferenceSnapshot {
Self::read(&mut reader)
}
}

impl Vocabulary {
// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
pub fn tokenize<'a>(
&'a self,
text: &str,
bos: bool,
) -> Result<Vec<(&'a str, TokenId)>, InferenceError> {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let Ok(sub) = std::str::from_utf8(sub) else { continue; };
let token = self.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
}
}
}
}

// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(InferenceError::TokenizationFailed);
}
let token = self.id_to_token[token_id as usize].as_str();
res.push((token, token_id));
i -= token.len();
}

if bos {
// TODO: replace with vocab.bos
res.push(("", 1));
}

// Pieces are in reverse order so correct that
res.reverse();

Ok(res)
}
}
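
To make the greedy preference of the dynamic programming pass in Vocabulary::tokenize explicit: a candidate piece scores the square of its byte length, so matching a span with one longer token always beats covering the same span with several shorter tokens. A tiny self-contained illustration (numbers only, no real vocabulary involved; not part of the commit):

// Illustrative arithmetic for Vocabulary::tokenize's scoring rule above.
fn main() {
    let whole = 3 * 3;              // a single 3-byte piece covering, say, "abc"
    let pair = 2 * 2 + 1 * 1;       // a 2-byte piece plus a 1-byte piece
    let singles = 1 + 1 + 1;        // three 1-byte pieces
    assert!(whole > pair && pair > singles);
}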
