From ec9052e3cd81fc2ed2d1eb8a76d1b3662d25fcf4 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Fri, 7 Jul 2023 05:50:03 -0600 Subject: [PATCH] Add optional llm_samplers sampler backend --- Cargo.lock | 24 ++ Cargo.toml | 1 + binaries/llm-cli/Cargo.toml | 2 + binaries/llm-cli/src/cli_args.rs | 115 +++--- binaries/llm-cli/src/interactive.rs | 2 +- binaries/llm-cli/src/main.rs | 7 +- binaries/llm-test/Cargo.toml | 1 + binaries/llm-test/src/inference.rs | 55 +-- crates/llm-base/Cargo.toml | 2 + crates/llm-base/src/inference_session.rs | 11 +- crates/llm-base/src/lib.rs | 15 +- crates/llm-base/src/samplers.rs | 432 +++++++++++++++-------- crates/llm-base/src/tokenizer/mod.rs | 6 + crates/llm/src/lib.rs | 2 +- 14 files changed, 452 insertions(+), 223 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1365e9a..65ca703a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1285,6 +1285,7 @@ dependencies = [ "bytemuck", "ggml", "half", + "llm-samplers", "memmap2", "partial_sort", "rand", @@ -1314,6 +1315,7 @@ dependencies = [ "env_logger", "is-terminal", "llm", + "llm-samplers", "log", "num_cpus", "rand", @@ -1370,6 +1372,18 @@ dependencies = [ "llm-base", ] +[[package]] +name = "llm-samplers" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eeb5ff99b934436bd5ff7d1cdc3a674bf5b99ab996f36214306b2565cf393872" +dependencies = [ + "anyhow", + "num-traits", + "rand", + "thiserror", +] + [[package]] name = "llm-test" version = "0.2.0-dev" @@ -1379,6 +1393,7 @@ dependencies = [ "env_logger", "indicatif", "llm", + "llm-samplers", "log", "rand", "reqwest", @@ -1571,6 +1586,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" diff --git a/Cargo.toml b/Cargo.toml index 59ad9021..f702cfca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } +llm-samplers = "=0.0.5" # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/binaries/llm-cli/Cargo.toml b/binaries/llm-cli/Cargo.toml index 1cd6c34d..5064b18c 100644 --- a/binaries/llm-cli/Cargo.toml +++ b/binaries/llm-cli/Cargo.toml @@ -35,6 +35,8 @@ tracing-appender = "0.2.2" # Remove this once we bump our MSRV to 1.70. is-terminal = "0.4" +llm-samplers = { workspace = true } + [dev-dependencies] rusty-hook = "^0.11.2" diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index dd107695..9b02703c 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -2,15 +2,16 @@ use std::{ fmt, ops::Deref, path::{Path, PathBuf}, - sync::Arc, + str::FromStr, }; -use clap::{Parser, ValueEnum}; +use clap::{ArgAction, Parser, ValueEnum}; use color_eyre::eyre::{self, WrapErr}; use llm::{ - ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, - LoadProgress, Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias, - TokenizerSource, + ggml_format, + samplers::{build_sampler, ConfiguredSampler}, + ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress, + Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias, TokenId, TokenizerSource, }; use rand::SeedableRng; @@ -244,29 +245,62 @@ pub struct Generate { #[arg(long, default_value_t = 8)] pub batch_size: usize, - /// Size of the 'last N' buffer that is used for the `repeat_penalty` - /// option. In tokens. - #[arg(long, default_value_t = 64)] - pub repeat_last_n: usize, - - /// The penalty for repeating tokens. Higher values make the generation less - /// likely to get into a loop, but may harm results when repetitive outputs - /// are desired. - #[arg(long, default_value_t = 1.30)] - pub repeat_penalty: f32, - - /// Temperature - #[arg(long, default_value_t = 0.80)] - pub temperature: f32, - - /// Top-K: The top K words by score are kept during sampling. - #[arg(long, default_value_t = 40)] - pub top_k: usize, - - /// Top-p: The cumulative probability after which no more words are kept - /// for sampling. - #[arg(long, default_value_t = 0.95)] - pub top_p: f32, + /// Configure sampler settings using a string in the format: sampler_name:key1=value1:key2=value2 + /// This option may be specified multiple times. + /// NOTE: If mirostat1 or mirostat2 samplers are configured then samplers other than repetition, frequency/presence and temperature will be ignored. + /// TIPS: + /// 1. Sampler options aren't required, but the colon after the name is. For example "mirostat1:" will enable Mirostat 1 with its default options. + /// 2. It's possible to specify partial option names, as long as they are unambiguous. + /// 3. You can skip the underscore in sampler (but not option) names. + /// + /// Configurable samplers (defaults shown in parenthesis): + /// + /// freq_presence (default: disabled) - Allows penalizing tokens for presence and frequency. + /// frequency_penalty(0.0): Penalty to apply to tokens based on frequency. For example, if a token has appeared 3 times within the last_n range then it will have its probability decreased by 3 * frequency_penalty. + /// presence_penalty(0.0): Penalty to apply to tokens that are already present within the last_n tokens. + /// last_n(64): Number of previous tokens to consider. + /// + /// locally_typical (default: disabled) - An approach to sampling that attempts to maximize natural and human-like output. See: https://arxiv.org/abs/2202.00666 + /// p(1.0): Referred to as τ in the paper. It suggests using 0.2 as a value for story generation and `0.95` for "abstractive summarization" (presumably this means more factual output). 1.0 appears to be the same as disabled which is similar to top-p sampling. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// mirostat1 (default: disabled) - See: https://arxiv.org/abs/2007.14966 + /// eta(0.1): Learning rate + /// tau(5.0): Target entropy + /// mu(tau * 2): Initial learning state value. Setting this is generally not recommended. + /// + /// mirostat2 (default: disabled) - See: https://arxiv.org/abs/2007.14966 + /// eta(0.1): Learning rate + /// tau(5.0): Target entropy + /// mu(tau * 2): Initial learning state value. Setting this is generally not recommended. + /// + /// repetition - Allows setting a repetition penalty. + /// penalty(1.30): The penalty for repeating tokens. Higher values make the generation less likely to get into a loop, but may harm results when repetitive outputs are desired. + /// last_n(64): Number of previous tokens to consider. + /// + /// tail_free (default: disabled) - An approach to sampling that attempts to outperform existing nucleus (top-p and top-k) methods. See: https://trentbrick.github.io/Tail-Free-Sampling/ + /// z(1.0): It is not entirely clear what a reasonable value here is but 1.0 appears to be the same as disabled which is similar to top-p sampling. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// temperature - Temperature used for sampling. + /// temperature(0.8): Temperature (randomness) used for sampling. A higher number is more random. + /// + /// top_k - The top k (or min_keep if it is greater) tokens by score are kept during sampling. + /// k(40): Number of tokens to keep. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen. + /// p(0.95): The cumulative probability after which no more tokens are kept for sampling. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + + #[arg( + long = "sampler", + short = 's', + verbatim_doc_comment, + action = ArgAction::Append, + value_parser = ConfiguredSampler::from_str + )] + pub configured_samplers: Vec, /// Specifies the seed to use during sampling. Note that, depending on /// hardware, the same seed may lead to different results on two separate @@ -295,8 +329,7 @@ pub struct Generate { pub token_bias: Option, /// Prevent the end of stream (EOS/EOD) token from being generated. This will allow the - /// model to generate text until it runs out of context space. Note: The --token-bias - /// option will override this if specified. + /// model to generate text until it runs out of context space. #[arg(long, default_value_t = false)] pub ignore_eos: bool, @@ -348,25 +381,17 @@ impl Generate { } } - pub fn inference_parameters(&self, eot: llm::TokenId) -> InferenceParameters { + pub fn inference_parameters(&self, eot: TokenId, n_vocab: usize) -> InferenceParameters { + let mut bias: Vec<(TokenId, f32)> = self.token_bias.clone().unwrap_or_default().into(); + if self.ignore_eos { + bias.push((eot, f32::NEG_INFINITY)); + } InferenceParameters { - sampler: Arc::new(llm::samplers::TopPTopK { - top_k: self.top_k, - top_p: self.top_p, - repeat_penalty: self.repeat_penalty, - temperature: self.temperature, - bias_tokens: self.token_bias.clone().unwrap_or_else(|| { - if self.ignore_eos { - TokenBias::new(vec![(eot, -1.0)]) - } else { - TokenBias::default() - } - }), - repetition_penalty_last_n: self.repeat_last_n, - }), + sampler: build_sampler(n_vocab, &bias, self.configured_samplers.clone()), } } } + fn parse_bias(s: &str) -> Result { s.parse() } diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs index ae72aa61..1ffac4a5 100644 --- a/binaries/llm-cli/src/interactive.rs +++ b/binaries/llm-cli/src/interactive.rs @@ -125,7 +125,7 @@ fn initialize_common_state( let model = model_load.load(generate.use_gpu)?; Ok(( generate.inference_session_config(), - generate.inference_parameters(model.eot_token_id()), + generate.inference_parameters(model.eot_token_id(), model.tokenizer().len()), model, generate.rng(), )) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index f6377516..f05f89fe 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -47,7 +47,9 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> { args.load_session.as_deref(), inference_session_config, ); - let parameters = args.generate.inference_parameters(model.eot_token_id()); + let parameters = args + .generate + .inference_parameters(model.eot_token_id(), model.tokenizer().len()); let mut rng = args.generate.rng(); @@ -94,6 +96,9 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> { Err(llm::InferenceError::TokenizationFailed(err)) => { log::error!("A tokenization-related failure occurred: {}", err); } + Err(llm::InferenceError::SamplerFailure(err)) => { + log::error!("A sampling-related failure occurred: {}", err); + } Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => { unreachable!("cannot fail") } diff --git a/binaries/llm-test/Cargo.toml b/binaries/llm-test/Cargo.toml index c1e349b8..cfb392c9 100644 --- a/binaries/llm-test/Cargo.toml +++ b/binaries/llm-test/Cargo.toml @@ -17,6 +17,7 @@ clap = { workspace = true } env_logger = { workspace = true } log = { workspace = true } rand = { workspace = true } +llm-samplers = { workspace = true } reqwest = "0.11.9" indicatif = "0.16.2" diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs index 5190bb9e..a9ace889 100644 --- a/binaries/llm-test/src/inference.rs +++ b/binaries/llm-test/src/inference.rs @@ -2,9 +2,14 @@ //! //! See [crate::TestCase::Inference]. -use std::{convert::Infallible, sync::Arc}; +use std::{ + convert::Infallible, + sync::{Arc, Mutex}, +}; -use llm::{InferenceSessionConfig, InferenceStats}; +use llm::{InferenceSessionConfig, InferenceStats, TokenId}; + +use llm_samplers::prelude::{HasSamplerResources, Logits, SampleFlatBias, SampleGreedy, Sampler}; use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta}; @@ -66,7 +71,7 @@ fn run_inference( &llm::InferenceRequest { prompt: input.into(), parameters: &llm::InferenceParameters { - sampler: Arc::new(DeterministicSampler), + sampler: Arc::new(Mutex::new(DeterministicSampler::default())), }, play_back_previous_tokens: false, maximum_token_count: Some(maximum_token_count), @@ -84,27 +89,29 @@ fn run_inference( (actual_output, res) } -#[derive(Debug)] -struct DeterministicSampler; -impl llm::Sampler for DeterministicSampler { - fn sample( - &self, - previous_tokens: &[llm::TokenId], - logits: &[f32], - _rng: &mut dyn rand::RngCore, - ) -> llm::TokenId { - // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` - // at all - let mut logits = logits.to_vec(); - for &token in previous_tokens { - logits[token as usize] = f32::NEG_INFINITY; - } +// Takes the most likely element from the logits, except if they've appeared in `previous_tokens` +// at all +#[derive(Debug, Default)] +struct DeterministicSampler(SampleGreedy); + +impl Sampler for DeterministicSampler { + fn sample<'a>( + &mut self, + res: &mut dyn HasSamplerResources, + logits: &'a mut Logits, + ) -> anyhow::Result<&'a mut Logits> { + let mut flat_bias = Default::default(); + + // This might look a little weird, but it's necessary because the resource + // `with_` functions can't return a value. + res.with_last_tokens(&mut |lt| { + flat_bias = SampleFlatBias::new(lt.iter().map(|tid| (*tid, f32::NEG_INFINITY))); + })?; + + logits.sample(res, &mut flat_bias)?.sample(res, &mut self.0) + } - logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .unwrap() - .0 as llm::TokenId + fn sampled_token_id(&self) -> Option { + *self.0 } } diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index 0216525f..4e6de051 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -26,6 +26,8 @@ tokenizers = {version="0.13.3", default-features=false, features=["onig"]} regex = "1.8" tracing = { workspace = true } +llm-samplers = { workspace = true } + [features] tokenizers-remote = ["tokenizers/http"] cublas = ["ggml/cublas"] diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index c60eefc3..bb8bc2b2 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -385,7 +385,13 @@ impl InferenceSession { return Err(InferenceError::ContextFull); } - let next_token = params.sampler.sample(&self.tokens, &self.last_logits, rng); + let next_token = crate::samplers::sample_token( + params.sampler.clone(), + rng, + &self.tokens, + self.last_logits.iter().copied(), + ) + .map_err(InferenceError::SamplerFailure)?; // Update the tokens for this session self.tokens.push(next_token); @@ -687,6 +693,9 @@ pub enum InferenceError { #[error("the user-specified callback returned an error")] /// The user-specified callback returned an error. UserCallback(Box), + /// Sampling returned an error. + #[error("token sampling failed")] + SamplerFailure(Box), } #[derive(Error, Debug)] diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index fae0de9e..80a34452 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -17,7 +17,7 @@ pub mod model; pub mod samplers; pub mod util; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; pub use ggml; pub use ggml::Type as ElementType; @@ -28,6 +28,7 @@ pub use inference_session::{ InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError, }; +pub use llm_samplers::prelude::{Sampler, SamplerChain}; pub use loader::{ load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, LoadError, LoadProgress, Loader, TensorLoader, @@ -37,7 +38,6 @@ pub use memmap2::Mmap; pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest}; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use regex::Regex; -pub use samplers::Sampler; pub use tokenizer::{ InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Tokenizer, TokenizerLoadError, TokenizerSource, @@ -57,9 +57,10 @@ pub struct InferenceParameters { /// from this distribution to generate the next token. Using a different sampler may /// change the output of the model, or control how deterministic the generated text is. /// - /// A recommended default sampler is [TopPTopK](samplers::TopPTopK), which is a standard - /// sampler that offers a [Default](samplers::TopPTopK::default) implementation. - pub sampler: Arc, + /// This can be anything that implements [Sampler]. Refer to + /// the `llm-samplers` documentation for possible samplers and suggested + /// combinations: + pub sampler: Arc>>, } //Since Sampler implements Send and Sync, InferenceParameters should too. @@ -68,8 +69,10 @@ unsafe impl Sync for InferenceParameters {} impl Default for InferenceParameters { fn default() -> Self { + let chain: SamplerChain = + crate::samplers::ConfiguredSamplers::default().into(); Self { - sampler: Arc::new(samplers::TopPTopK::default()), + sampler: Arc::new(Mutex::new(chain)), } } } diff --git a/crates/llm-base/src/samplers.rs b/crates/llm-base/src/samplers.rs index 13196a19..c4438bea 100644 --- a/crates/llm-base/src/samplers.rs +++ b/crates/llm-base/src/samplers.rs @@ -1,162 +1,306 @@ -//! Defines the samplers used for generation. -//! -//! You can define your own [Sampler] by implementing the trait. - -use std::fmt::Debug; - -use partial_sort::PartialSort; -use rand::{distributions::WeightedIndex, prelude::Distribution}; - -use crate::{TokenBias, TokenId}; - -/// A sampler for generation. -pub trait Sampler: Debug + Send + Sync { - /// Given the previous tokens, the logits from the most recent evaluation, and a source of randomness, - /// sample from the logits and return the token ID. - fn sample( - &self, - previous_tokens: &[TokenId], - logits: &[f32], - rng: &mut dyn rand::RngCore, - ) -> TokenId; +//! Types and methods used for constructing and running +//! the samplers used for generation. + +use std::{ + error::Error, + fmt, + str::FromStr, + sync::{Arc, Mutex}, +}; + +use llm_samplers::prelude::*; + +use crate::TokenId; + +/// This structure holds specific samplers that have already +/// been configured and provides some convenience methods +/// for constructing samplers with default settings. +#[derive(Debug, Default, Clone)] +pub struct ConfiguredSamplers { + bias: Option>, + repetition: Option>, + freq_presence: Option>, + top_k: Option, + tail_free: Option>, + locally_typical: Option>, + top_p: Option>, + temperature: Option>, + mirostat1: Option>, + mirostat2: Option>, } -/// Top-P Top-K sampling. -/// -/// A standard sampler that uses top-K sampling (the top-K tokens with the highest -/// probability are considered) and top-P sampling (only tokens with a cumulative -/// probability of `P` are considered). -/// -/// It also implements [CTRL](https://arxiv.org/abs/1909.05858)'s repetition penalty, -/// and the ability to bias the generation of individual tokens. -#[derive(Clone, Debug)] -pub struct TopPTopK { - /// The top K words by score are kept during sampling. - pub top_k: usize, - /// The cumulative probability after which no more words are kept for sampling. - pub top_p: f32, - /// The penalty for repeating tokens. Higher values make the generation less - /// likely to get into a loop, but may harm results when repetitive outputs - /// are desired. - pub repeat_penalty: f32, - /// Temperature (randomness) used for sampling. A higher number is more random. - pub temperature: f32, - /// A list of tokens to bias against in the process of generation. - pub bias_tokens: TokenBias, - /// The number of tokens to consider for the repetition penalty. - pub repetition_penalty_last_n: usize, +impl ConfiguredSamplers { + /// Sets the token bias list + pub fn set_token_bias(&mut self, bias: impl IntoIterator) { + self.bias = Some(SampleFlatBias::new(bias)) + } + + /// Creates a temperature new sampler with default options. + pub fn new_temperature() -> SampleTemperature { + SampleTemperature::default().temperature(0.8) + } + + /// Creates a new repetition sampler with default options. + pub fn new_repetition() -> SampleRepetition { + SampleRepetition::default().penalty(1.30).last_n(64) + } + + /// Creates a new frequency/presence sampler with default options. + pub fn new_freq_presence() -> SampleFreqPresence { + SampleFreqPresence::default() + .frequency(0.0) + .presence(0.0) + .last_n(64) + } + + /// Creates a new top k sampler with default options. + pub fn new_top_k() -> SampleTopK { + SampleTopK::default().k(40) + } + + /// Creates a new top p sampler with default options. + pub fn new_top_p() -> SampleTopP { + SampleTopP::default().p(0.95) + } + + /// Creates a new tail free sampler with default options. + pub fn new_tail_free() -> SampleTailFree { + SampleTailFree::default().z(1.0) + } + + /// Creates a new locally typical sampler with default options. + pub fn new_locally_typical() -> SampleLocallyTypical { + SampleLocallyTypical::default().p(1.0) + } + + /// Creates a new mirostat 1 sampler with default options. + pub fn new_mirostat1() -> SampleMirostat1 { + SampleMirostat1::default().eta(0.1).tau(5.0) + } + + /// Creates a new mirostat 2 sampler with default options. + pub fn new_mirostat2() -> SampleMirostat2 { + SampleMirostat2::default().eta(0.1).tau(5.0) + } } -impl Default for TopPTopK { - fn default() -> Self { - Self { - top_k: 40, - top_p: 0.95, - repeat_penalty: 1.30, - temperature: 0.80, - bias_tokens: TokenBias::empty(), - repetition_penalty_last_n: 512, + +impl From for SamplerChain { + fn from(val: ConfiguredSamplers) -> Self { + let mut chain = SamplerChain::new(); + + if let Some(sampler) = val.bias { + chain += sampler; + } + if let Some(sampler) = val.repetition { + chain += sampler; + } + if let Some(sampler) = val.freq_presence { + chain += sampler; + } + + if let Some(mirosampler) = val.mirostat1 { + if let Some(sampler) = val.temperature { + chain += sampler; + } + chain += mirosampler; + return chain; + } else if let Some(mirosampler) = val.mirostat2 { + if let Some(sampler) = val.temperature { + chain += sampler; + } + chain += mirosampler; + return chain; + } + + if let Some(sampler) = val.top_k { + chain += sampler; + } + if let Some(sampler) = val.tail_free { + chain += sampler; + } + if let Some(sampler) = val.locally_typical { + chain += sampler; + } + if let Some(sampler) = val.top_p { + chain += sampler; + } + if let Some(sampler) = val.temperature { + chain += sampler; } + chain += SampleRandDistrib::new(); + chain } } -impl Sampler for TopPTopK { - fn sample( - &self, - previous_tokens: &[TokenId], - logits: &[f32], - rng: &mut dyn rand::RngCore, - ) -> TokenId { - let Self { - top_k, - top_p, - repeat_penalty, - temperature, - repetition_penalty_last_n, - .. - } = *self; - let bias_tokens = &self.bias_tokens; - - let n_logits = logits.len(); - let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits); - - // TODO: consider if this can be modularized and this sampler can be composed out of multiple pieces, - // instead of having this monolithic function that embeds the repetition penalty and token bias - { - let scale = 1.0 / temperature; - for (i, &logit) in logits.iter().enumerate() { - let tid = i as TokenId; - - let val = if let Some(logit_override) = bias_tokens.get(tid) { - logit_override - } else if previous_tokens[previous_tokens - .len() - .saturating_sub(repetition_penalty_last_n)..] - .contains(&(i as TokenId)) - { - // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if logits[i] < 0.0 { - logit * scale * repeat_penalty - } else { - logit * scale / repeat_penalty - } - } else { - logit * scale - }; - logits_id.push((val, tid)); + +impl ConfiguredSamplers { + fn from_args(args: Vec, n_vocab: usize) -> Self { + let mut result = Self::default(); + + args.into_iter().for_each(|arg| match arg { + ConfiguredSampler::Repetition(sampler) => result.repetition = Some(sampler), + ConfiguredSampler::FreqPresence(sampler) => result.freq_presence = Some(sampler), + ConfiguredSampler::TopK(sampler) => result.top_k = Some(sampler), + ConfiguredSampler::TailFree(sampler) => result.tail_free = Some(sampler), + ConfiguredSampler::LocallyTypical(sampler) => result.locally_typical = Some(sampler), + ConfiguredSampler::TopP(sampler) => result.top_p = Some(sampler), + ConfiguredSampler::Temperature(sampler) => result.temperature = Some(sampler), + ConfiguredSampler::Mirostat1(sampler) => { + result.mirostat1 = Some(sampler.n_vocab(n_vocab)) } - } + ConfiguredSampler::Mirostat2(sampler) => result.mirostat2 = Some(sampler), + }); - // find the top K tokens - { - logits_id.partial_sort(top_k, |a, b| { - // Sort descending - b.0.total_cmp(&a.0) - }); - logits_id.truncate(top_k); + if result.temperature.is_none() { + result.temperature = Some(ConfiguredSamplers::new_temperature()) + } + if result.repetition.is_none() { + result.repetition = Some(ConfiguredSamplers::new_repetition()) + } + if result.mirostat1.is_some() || result.mirostat2.is_some() { + return result; } - let maxl = logits_id - .iter() - .map(|x| x.0) - .max_by(f32::total_cmp) - .unwrap(); - - // compute probs for the top K tokens - let mut probs: Vec = logits_id - .iter() - .copied() - .map(|(k, _)| (k - maxl).exp()) - .collect(); - let sum: f32 = probs.iter().copied().sum(); - - // Normalize the probs - for p in probs.iter_mut() { - *p /= sum; + if result.top_k.is_none() { + result.top_k = Some(ConfiguredSamplers::new_top_k()) + } + if result.top_p.is_none() { + result.top_p = Some(ConfiguredSamplers::new_top_p()) } + result + } +} - // Top p sampling - if top_p < 1.0 { - let mut cumsum = 0.0; - for i in 0..probs.len() { - cumsum += probs[i]; - if cumsum >= top_p { - probs.truncate(i + 1); - logits_id.truncate(i + 1); - break; - } - } +/// A specific type of sampler that has been configured +#[derive(Clone, Debug)] +pub enum ConfiguredSampler { + /// Holds the configured sampler + Repetition(SampleRepetition), + /// Holds the configured sampler + FreqPresence(SampleFreqPresence), + /// Holds the configured sampler + TopK(SampleTopK), + /// Holds the configured sampler + TailFree(SampleTailFree), + /// Holds the configured sampler + LocallyTypical(SampleLocallyTypical), + /// Holds the configured sampler + TopP(SampleTopP), + /// Holds the configured sampler + Temperature(SampleTemperature), + /// Holds the configured sampler + Mirostat1(SampleMirostat1), + /// Holds the configured sampler + Mirostat2(SampleMirostat2), +} + +impl FromStr for ConfiguredSampler { + type Err = Box; - cumsum = 1.0 / cumsum; - for p in probs.iter_mut() { - *p *= cumsum; + fn from_str(s: &str) -> Result { + let (name, args) = if let Some(val) = s.split_once(':') { + val + } else { + return Err(Box::from("Bad format for sampler argument")); + }; + + Ok(match name.trim() { + "repetition" => ConfiguredSamplers::new_repetition() + .configure(args) + .map(Self::Repetition)?, + "frequency" | "presence" | "freqpresence" => ConfiguredSamplers::new_freq_presence() + .configure(args) + .map(Self::FreqPresence)?, + "topk" | "top_k" => { + ConfigurableSampler::<_, f32>::configure(ConfiguredSamplers::new_top_k(), args) + .map(Self::TopK)? } - } + "topp" | "top_p" => ConfiguredSamplers::new_top_p() + .configure(args) + .map(Self::TopP)?, + "temperature" | "temp" => ConfigurableSampler::::configure( + ConfiguredSamplers::new_temperature(), + args, + ) + .map(Self::Temperature)?, + "tailfree" | "tail_free" => ConfiguredSamplers::new_tail_free() + .configure(args) + .map(Self::TailFree)?, + "locallytypical" | "locally_typical" => ConfiguredSamplers::new_locally_typical() + .configure(args) + .map(Self::LocallyTypical)?, + "mirostat1" => ConfiguredSamplers::new_mirostat1() + .configure(args) + .map(Self::Mirostat1)?, + "mirostat2" => ConfiguredSamplers::new_mirostat2() + .configure(args) + .map(Self::Mirostat2)?, + unknown => return Err(Box::from(format!("Unknown sampler: {unknown}"))), + }) + } +} + +/// Sample a token. This convenience function handles building +/// the sampler resources and logits objects the sampler needs. +pub fn sample_token( + mut sampler: impl Sampler, + rng: &mut impl rand::Rng, + previous_tokens: &[TokenId], + last_logits: impl IntoIterator, +) -> Result> { + Logits::try_from_iter(last_logits.into_iter())? + .sample_token( + &mut SamplerResources { + previous_tokens, + rng, + }, + &mut sampler, + )? + .ok_or_else(|| Box::from("sampler did not return a token")) +} + +/// Build a sampler with the supplied options, vocab size and token bias list. +pub fn build_sampler( + n_vocab: usize, + bias: &[(TokenId, f32)], + args: Vec, +) -> Arc>> { + let mut settings = ConfiguredSamplers::from_args(args, n_vocab); + if !bias.is_empty() { + settings.set_token_bias(bias.iter().copied()) + } + let chain: SamplerChain = settings.into(); + Arc::new(Mutex::new(chain)) +} + +// Struct used to temporarily hold resources for the `llm_samplers` +// sampler. +struct SamplerResources<'pt, 'r> { + previous_tokens: &'pt [TokenId], + rng: &'r mut dyn rand::RngCore, +} + +impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SamplerResources") + .field("previous_tokens", &self.previous_tokens) + .field("rng", &"") + .finish() + } +} - let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); - let idx = dist.sample(rng); +impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { + type TokenId = TokenId; + + fn with_rng_mut( + &mut self, + fun: &mut dyn FnMut(&mut dyn rand::RngCore), + ) -> Result<(), SamplerError> { + fun(self.rng); + Ok(()) + } - logits_id[idx].1 + fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> { + fun(self.previous_tokens); + Ok(()) } } diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 52cbee00..03b2f0b9 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -298,6 +298,12 @@ impl TokenBias { } } +impl From for Vec<(TokenId, f32)> { + fn from(val: TokenBias) -> Self { + val.0 + } +} + impl FromStr for TokenBias { type Err = InvalidTokenBias; diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 37514511..febe2441 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -85,7 +85,7 @@ pub use llm_base::{ InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, - ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, Sampler, + ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, };