Rename flag to --token-bias, add --ignore-eos alias.
KerfuffleV2 committed Mar 23, 2023
1 parent 8e94159 commit 67950ec
Showing 3 changed files with 21 additions and 7 deletions.
llama-cli/src/cli_args.rs: 7 additions & 1 deletion
@@ -91,7 +91,13 @@ pub struct Args {
     /// (start of document) and 2 (end of document) to -1.0 which effectively
     /// disables the model from generating responses containing those token IDs.
     #[arg(long, default_value = None, value_parser = parse_bias)]
-    pub bias_tokens: Option<TokenBias>,
+    pub token_bias: Option<TokenBias>,
+
+    /// Prevent the end of stream (EOS/EOD) token from being generated. This will allow the
+    /// model to generate text until it runs out of context space. Note: The --token-bias
+    /// option will override this if specified.
+    #[arg(long, default_value_t = false)]
+    pub ignore_eos: bool,
 
     /// Dumps the prompt to console and exits, first as a comma separated list of token IDs
     /// and then as a list of comma separated string keys and token ID values.
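The --token-bias value is parsed by parse_bias, which is not part of this diff. As a rough sketch of what such a clap value parser could look like, assuming a hypothetical comma-separated token_id=bias syntax (e.g. "2=-1.0") and a TokenBias::new constructor over (TokenId, f32) pairs; the real parser in llama-rs may accept a different format:

    // Illustrative sketch only: the assumed input syntax "id=bias,id=bias"
    // (e.g. "2=-1.0") and the local TokenBias/TokenId definitions are
    // assumptions, not taken from this commit.
    type TokenId = i32;

    #[derive(Clone, Debug, Default)]
    pub struct TokenBias(Vec<(TokenId, f32)>);

    impl TokenBias {
        pub fn new(pairs: Vec<(TokenId, f32)>) -> Self {
            TokenBias(pairs)
        }
    }

    fn parse_bias(s: &str) -> Result<TokenBias, String> {
        let mut pairs = Vec::new();
        for part in s.split(',').filter(|p| !p.trim().is_empty()) {
            let (id, bias) = part
                .split_once('=')
                .ok_or_else(|| format!("expected token_id=bias, got {part:?}"))?;
            let id: TokenId = id.trim().parse().map_err(|e| format!("bad token id: {e}"))?;
            let bias: f32 = bias.trim().parse().map_err(|e| format!("bad bias: {e}"))?;
            pairs.push((id, bias));
        }
        Ok(TokenBias::new(pairs))
    }

Under that assumed syntax, an invocation might look like llama-cli --token-bias "1=-1.0,2=-1.0", while --ignore-eos takes no value.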
llama-cli/src/main.rs: 8 additions & 2 deletions
@@ -3,7 +3,7 @@ use std::{convert::Infallible, io::Write};
 use cli_args::CLI_ARGS;
 use llama_rs::{
     InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
-    ModelKVMemoryType, Vocabulary,
+    ModelKVMemoryType, TokenBias, Vocabulary, EOD_TOKEN_ID,
 };
 use rand::thread_rng;
 use rand::SeedableRng;
@@ -105,7 +105,13 @@ fn main() {
         top_p: args.top_p,
         repeat_penalty: args.repeat_penalty,
         temp: args.temp,
-        bias_tokens: args.bias_tokens.clone().unwrap_or_default(),
+        bias_tokens: args.token_bias.clone().unwrap_or_else(|| {
+            if args.ignore_eos {
+                TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)])
+            } else {
+                TokenBias::default()
+            }
+        }),
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
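In effect, an explicit --token-bias replaces the bias table entirely, so it overrides --ignore-eos rather than merging with it; --ignore-eos alone becomes a one-entry table biasing EOD_TOKEN_ID to -1.0. How a sampler consumes such a table is not part of this commit; a minimal sketch of one plausible approach, where -1.0 is assumed to mean "never generate this token", is:

    // Illustrative only; the real llama-rs sampling code may apply TokenBias
    // differently. Treating a bias of -1.0 as a hard mask is an assumption.
    type TokenId = i32;

    fn apply_token_bias(logits: &mut [f32], bias: &[(TokenId, f32)]) {
        for &(id, b) in bias {
            if let Some(logit) = logits.get_mut(id as usize) {
                if b == -1.0 {
                    // Assumed semantics: -1.0 disables the token entirely.
                    *logit = f32::NEG_INFINITY;
                } else {
                    *logit += b;
                }
            }
        }
    }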
llama-rs/src/lib.rs: 6 additions & 4 deletions
@@ -14,6 +14,8 @@ use thiserror::Error;
 use partial_sort::PartialSort;
 use rand::{distributions::WeightedIndex, prelude::Distribution};
 
+pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)
+
 #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 pub struct Hyperparameters {
     n_vocab: i32,
@@ -1361,11 +1363,11 @@ impl InferenceSession {
         model.evaluate(self, params.n_threads, &[next_token]);
 
         // Return the next token
-        if next_token == 2 {
-            Ok(OutputToken::EndOfText)
+        Ok(if next_token as TokenId == EOD_TOKEN_ID {
+            OutputToken::EndOfText
         } else {
-            Ok(OutputToken::Token(&vocab.id_to_token[next_token as usize]))
-        }
+            OutputToken::Token(&vocab.id_to_token[next_token as usize])
+        })
     }
 
     // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe?
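With --ignore-eos in effect, EOD_TOKEN_ID is never sampled, so the OutputToken::EndOfText arm above is never taken and generation continues until the context window is exhausted. A minimal caller loop, with hypothetical names standing in for the actual llama-rs inference API, might look like:

    // Hypothetical caller loop; the enum mirrors OutputToken above, but the
    // closure-based next() stands in for the real inference call.
    enum OutputToken<'a> {
        Token(&'a str),
        EndOfText,
    }

    fn drive(mut next: impl FnMut() -> OutputToken<'static>, max_tokens: usize) {
        for _ in 0..max_tokens {
            match next() {
                // Normally generation stops here; with --ignore-eos the EOS/EOD
                // token is biased away, so this arm is never reached.
                OutputToken::EndOfText => break,
                OutputToken::Token(text) => print!("{text}"),
            }
        }
    }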
