Skip to content

Commit

Permalink
Add optional llm_samplers sampler backend
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 authored and AmineDiro committed Aug 15, 2023
1 parent b1249c9 commit 643b6a7
Show file tree
Hide file tree
Showing 14 changed files with 452 additions and 223 deletions.
24 changes: 24 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1", features = ["log"] }
llm-samplers = "=0.0.5"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
2 changes: 2 additions & 0 deletions binaries/llm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ tracing-appender = "0.2.2"
# Remove this once we bump our MSRV to 1.70.
is-terminal = "0.4"

llm-samplers = { workspace = true }

[dev-dependencies]
rusty-hook = "^0.11.2"

Expand Down
115 changes: 70 additions & 45 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ use std::{
fmt,
ops::Deref,
path::{Path, PathBuf},
sync::Arc,
str::FromStr,
};

use clap::{Parser, ValueEnum};
use clap::{ArgAction, Parser, ValueEnum};
use color_eyre::eyre::{self, WrapErr};
use llm::{
ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
LoadProgress, Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias,
TokenizerSource,
ggml_format,
samplers::{build_sampler, ConfiguredSampler},
ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress,
Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias, TokenId, TokenizerSource,
};
use rand::SeedableRng;

Expand Down Expand Up @@ -244,29 +245,62 @@ pub struct Generate {
#[arg(long, default_value_t = 8)]
pub batch_size: usize,

/// Size of the 'last N' buffer that is used for the `repeat_penalty`
/// option. In tokens.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,

/// The penalty for repeating tokens. Higher values make the generation less
/// likely to get into a loop, but may harm results when repetitive outputs
/// are desired.
#[arg(long, default_value_t = 1.30)]
pub repeat_penalty: f32,

/// Temperature
#[arg(long, default_value_t = 0.80)]
pub temperature: f32,

/// Top-K: The top K words by score are kept during sampling.
#[arg(long, default_value_t = 40)]
pub top_k: usize,

/// Top-p: The cumulative probability after which no more words are kept
/// for sampling.
#[arg(long, default_value_t = 0.95)]
pub top_p: f32,
/// Configure sampler settings using a string in the format: sampler_name:key1=value1:key2=value2
/// This option may be specified multiple times.
/// NOTE: If mirostat1 or mirostat2 samplers are configured then samplers other than repetition, frequency/presence and temperature will be ignored.
/// TIPS:
/// 1. Sampler options aren't required, but the colon after the name is. For example "mirostat1:" will enable Mirostat 1 with its default options.
/// 2. It's possible to specify partial option names, as long as they are unambiguous.
/// 3. You can skip the underscore in sampler (but not option) names.
///
/// Configurable samplers (defaults shown in parenthesis):
///
/// freq_presence (default: disabled) - Allows penalizing tokens for presence and frequency.
/// frequency_penalty(0.0): Penalty to apply to tokens based on frequency. For example, if a token has appeared 3 times within the last_n range then it will have its probability decreased by 3 * frequency_penalty.
/// presence_penalty(0.0): Penalty to apply to tokens that are already present within the last_n tokens.
/// last_n(64): Number of previous tokens to consider.
///
/// locally_typical (default: disabled) - An approach to sampling that attempts to maximize natural and human-like output. See: https://arxiv.org/abs/2202.00666
/// p(1.0): Referred to as τ in the paper. It suggests using 0.2 as a value for story generation and `0.95` for "abstractive summarization" (presumably this means more factual output). 1.0 appears to be the same as disabled which is similar to top-p sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// mirostat1 (default: disabled) - See: https://arxiv.org/abs/2007.14966
/// eta(0.1): Learning rate
/// tau(5.0): Target entropy
/// mu(tau * 2): Initial learning state value. Setting this is generally not recommended.
///
/// mirostat2 (default: disabled) - See: https://arxiv.org/abs/2007.14966
/// eta(0.1): Learning rate
/// tau(5.0): Target entropy
/// mu(tau * 2): Initial learning state value. Setting this is generally not recommended.
///
/// repetition - Allows setting a repetition penalty.
/// penalty(1.30): The penalty for repeating tokens. Higher values make the generation less likely to get into a loop, but may harm results when repetitive outputs are desired.
/// last_n(64): Number of previous tokens to consider.
///
/// tail_free (default: disabled) - An approach to sampling that attempts to outperform existing nucleus (top-p and top-k) methods. See: https://trentbrick.github.io/Tail-Free-Sampling/
/// z(1.0): It is not entirely clear what a reasonable value here is but 1.0 appears to be the same as disabled which is similar to top-p sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// temperature - Temperature used for sampling.
/// temperature(0.8): Temperature (randomness) used for sampling. A higher number is more random.
///
/// top_k - The top k (or min_keep if it is greater) tokens by score are kept during sampling.
/// k(40): Number of tokens to keep.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen.
/// p(0.95): The cumulative probability after which no more tokens are kept for sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
#[arg(
long = "sampler",
short = 's',
verbatim_doc_comment,
action = ArgAction::Append,
value_parser = ConfiguredSampler::from_str
)]
pub configured_samplers: Vec<ConfiguredSampler>,

/// Specifies the seed to use during sampling. Note that, depending on
/// hardware, the same seed may lead to different results on two separate
Expand Down Expand Up @@ -295,8 +329,7 @@ pub struct Generate {
pub token_bias: Option<TokenBias>,

/// Prevent the end of stream (EOS/EOD) token from being generated. This will allow the
/// model to generate text until it runs out of context space. Note: The --token-bias
/// option will override this if specified.
/// model to generate text until it runs out of context space.
#[arg(long, default_value_t = false)]
pub ignore_eos: bool,

Expand Down Expand Up @@ -348,25 +381,17 @@ impl Generate {
}
}

pub fn inference_parameters(&self, eot: llm::TokenId) -> InferenceParameters {
pub fn inference_parameters(&self, eot: TokenId, n_vocab: usize) -> InferenceParameters {
let mut bias: Vec<(TokenId, f32)> = self.token_bias.clone().unwrap_or_default().into();
if self.ignore_eos {
bias.push((eot, f32::NEG_INFINITY));
}
InferenceParameters {
sampler: Arc::new(llm::samplers::TopPTopK {
top_k: self.top_k,
top_p: self.top_p,
repeat_penalty: self.repeat_penalty,
temperature: self.temperature,
bias_tokens: self.token_bias.clone().unwrap_or_else(|| {
if self.ignore_eos {
TokenBias::new(vec![(eot, -1.0)])
} else {
TokenBias::default()
}
}),
repetition_penalty_last_n: self.repeat_last_n,
}),
sampler: build_sampler(n_vocab, &bias, self.configured_samplers.clone()),
}
}
}

fn parse_bias(s: &str) -> Result<TokenBias, InvalidTokenBias> {
s.parse()
}
Expand Down
2 changes: 1 addition & 1 deletion binaries/llm-cli/src/interactive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn initialize_common_state(
let model = model_load.load(generate.use_gpu)?;
Ok((
generate.inference_session_config(),
generate.inference_parameters(model.eot_token_id()),
generate.inference_parameters(model.eot_token_id(), model.tokenizer().len()),
model,
generate.rng(),
))
Expand Down
7 changes: 6 additions & 1 deletion binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
args.load_session.as_deref(),
inference_session_config,
);
let parameters = args.generate.inference_parameters(model.eot_token_id());
let parameters = args
.generate
.inference_parameters(model.eot_token_id(), model.tokenizer().len());

let mut rng = args.generate.rng();

Expand Down Expand Up @@ -94,6 +96,9 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
Err(llm::InferenceError::TokenizationFailed(err)) => {
log::error!("A tokenization-related failure occurred: {}", err);
}
Err(llm::InferenceError::SamplerFailure(err)) => {
log::error!("A sampling-related failure occurred: {}", err);
}
Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
unreachable!("cannot fail")
}
Expand Down
1 change: 1 addition & 0 deletions binaries/llm-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ clap = { workspace = true }
env_logger = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
llm-samplers = { workspace = true }

reqwest = "0.11.9"
indicatif = "0.16.2"
Expand Down
55 changes: 31 additions & 24 deletions binaries/llm-test/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
//!
//! See [crate::TestCase::Inference].
use std::{convert::Infallible, sync::Arc};
use std::{
convert::Infallible,
sync::{Arc, Mutex},
};

use llm::{InferenceSessionConfig, InferenceStats};
use llm::{InferenceSessionConfig, InferenceStats, TokenId};

use llm_samplers::prelude::{HasSamplerResources, Logits, SampleFlatBias, SampleGreedy, Sampler};

use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta};

Expand Down Expand Up @@ -66,7 +71,7 @@ fn run_inference(
&llm::InferenceRequest {
prompt: input.into(),
parameters: &llm::InferenceParameters {
sampler: Arc::new(DeterministicSampler),
sampler: Arc::new(Mutex::new(DeterministicSampler::default())),
},
play_back_previous_tokens: false,
maximum_token_count: Some(maximum_token_count),
Expand All @@ -84,27 +89,29 @@ fn run_inference(
(actual_output, res)
}

#[derive(Debug)]
struct DeterministicSampler;
impl llm::Sampler for DeterministicSampler {
fn sample(
&self,
previous_tokens: &[llm::TokenId],
logits: &[f32],
_rng: &mut dyn rand::RngCore,
) -> llm::TokenId {
// Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
// at all
let mut logits = logits.to_vec();
for &token in previous_tokens {
logits[token as usize] = f32::NEG_INFINITY;
}
// Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
// at all
#[derive(Debug, Default)]
struct DeterministicSampler(SampleGreedy<TokenId>);

impl Sampler<TokenId, f32> for DeterministicSampler {
fn sample<'a>(
&mut self,
res: &mut dyn HasSamplerResources<TokenId = TokenId>,
logits: &'a mut Logits<TokenId, f32>,
) -> anyhow::Result<&'a mut Logits<TokenId, f32>> {
let mut flat_bias = Default::default();

// This might look a little weird, but it's necessary because the resource
// `with_` functions can't return a value.
res.with_last_tokens(&mut |lt| {
flat_bias = SampleFlatBias::new(lt.iter().map(|tid| (*tid, f32::NEG_INFINITY)));
})?;

logits.sample(res, &mut flat_bias)?.sample(res, &mut self.0)
}

logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as llm::TokenId
fn sampled_token_id(&self) -> Option<TokenId> {
*self.0
}
}
2 changes: 2 additions & 0 deletions crates/llm-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ tokenizers = {version="0.13.3", default-features=false, features=["onig"]}
regex = "1.8"
tracing = { workspace = true }

llm-samplers = { workspace = true }

[features]
tokenizers-remote = ["tokenizers/http"]
cublas = ["ggml/cublas"]
Expand Down
11 changes: 10 additions & 1 deletion crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,13 @@ impl InferenceSession {
return Err(InferenceError::ContextFull);
}

let next_token = params.sampler.sample(&self.tokens, &self.last_logits, rng);
let next_token = crate::samplers::sample_token(
params.sampler.clone(),
rng,
&self.tokens,
self.last_logits.iter().copied(),
)
.map_err(InferenceError::SamplerFailure)?;

// Update the tokens for this session
self.tokens.push(next_token);
Expand Down Expand Up @@ -687,6 +693,9 @@ pub enum InferenceError {
#[error("the user-specified callback returned an error")]
/// The user-specified callback returned an error.
UserCallback(Box<dyn std::error::Error + Send + Sync>),
/// Sampling returned an error.
#[error("token sampling failed")]
SamplerFailure(Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Error, Debug)]
Expand Down
Loading

0 comments on commit 643b6a7

Please sign in to comment.