Skip to content

Commit

Permalink
Update to llm-samplers v0.0.7
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Nov 6, 2023
1 parent aeafa44 commit 5fa9bb2
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 28 deletions.
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1", features = ["log"] }
llm-samplers = "=0.0.6"
llm-samplers = { git = "https://github.com/KerfuffleV2/llm-samplers", branch = "feat-v0.0.7" }
# llm-samplers = "=0.0.6"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
9 changes: 9 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ pub struct Generate {
/// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen.
/// p(0.95): The cumulative probability after which no more tokens are kept for sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information.
/// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler.
/// a2(0.0): Threshold power. A reasonable value is 2.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example if `p` is `0.05` then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information.
/// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
#[arg(long = "sampler", short = 's', verbatim_doc_comment)]
pub sampler_options: Vec<String>,

Expand Down
10 changes: 5 additions & 5 deletions binaries/llm-test/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ fn run_inference(
// Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
// at all
#[derive(Debug, Default)]
struct DeterministicSampler(SampleGreedy<TokenId>);
struct DeterministicSampler(SampleGreedy);

impl Sampler<TokenId, f32> for DeterministicSampler {
impl Sampler for DeterministicSampler {
fn sample<'a>(
&mut self,
res: &mut dyn HasSamplerResources<TokenId = TokenId>,
logits: &'a mut Logits<TokenId, f32>,
) -> anyhow::Result<&'a mut Logits<TokenId, f32>> {
res: &mut dyn HasSamplerResources,
logits: &'a mut Logits,
) -> anyhow::Result<&'a mut Logits> {
let mut flat_bias = Default::default();

// This might look a little weird, but it's necessary because the resource
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub struct InferenceParameters {
/// This can be anything that implements [Sampler]. Refer to
/// the `llm-samplers` documentation for possible samplers and suggested
/// combinations: <https://docs.rs/llm-samplers>
pub sampler: Arc<Mutex<dyn Sampler<TokenId, f32>>>,
pub sampler: Arc<Mutex<dyn Sampler>>,
}

//Since Sampler implements Send and Sync, InferenceParameters should too.
Expand Down
52 changes: 34 additions & 18 deletions crates/llm-base/src/samplers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub enum SamplingError {
/// to ensure a valid configuration.
pub struct ConfiguredSamplers {
/// A builder from the `llm-samplers` crate.
pub builder: SamplerChainBuilder,
pub builder: SamplerChainBuilder<usize, f32>,
/// Mirostat 1 is present.
pub mirostat1: bool,
/// Mirostat 2 is present.
Expand All @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers {
/// We call a configuration of samplers that run in a certain order a "chain".
/// Here is a description of the default chain `llm` uses:
///
/// 1. Repetition (present by default, multiple allowed)
/// 2. Frequency/Presence (optional, multiple allowed)
/// 3. Sequence Repetition (optional, multiple allowed)
/// 4. Top-K (present by default - incompatible with Mirostat)
/// 5. Tail Free (optional - incompatible with Mirostat)
/// 6. Locally Typical (optional - incompatible with Mirostat)
/// 7. Top-P (present by default - incompatible with Mirostat)
/// 8. Temperature (present by default)
/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution.
/// 1. Repetition (present by default, multiple allowed)
/// 2. Frequency/Presence (optional, multiple allowed)
/// 3. Sequence Repetition (optional, multiple allowed)
/// 4. Top-K (present by default - incompatible with Mirostat)
/// 5. Tail Free (optional - incompatible with Mirostat)
/// 6. Locally Typical (optional - incompatible with Mirostat)
/// 7. Top-P (present by default - incompatible with Mirostat)
/// 8. Top-A (optional - incompatible with Mirostat)
/// 9. Min-P (optional - incompatible with Mirostat)
/// 10. Temperature (present by default)
/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution.
///
/// Samplers listed as "present by default" but incompatible with Mirostat will
/// only be enabled by default if there is no Mirostat sampler enabled.
Expand Down Expand Up @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers {
Option::<SampleTopP>::None,
),
),
(
"topa",
SamplerSlot::new_single(
|| Box::new(SampleTopA::default().a1(0.0).a2(0.0)),
Option::<SampleTopA>::None,
),
),
(
"minp",
SamplerSlot::new_single(
|| Box::new(SampleMinP::default().p(0.0)),
Option::<SampleMinP>::None,
),
),
(
"temperature",
SamplerSlot::new_single(
Expand Down Expand Up @@ -203,7 +219,7 @@ impl ConfiguredSamplers {
))?
} else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat {
Err(SamplerConfigurationError::SamplerCombinationError(
"Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(),
"Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(),
))?
}
Ok(())
Expand Down Expand Up @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers {
.inspect(|(name, _slot)| match name.as_str() {
"mirostat1" => result.mirostat1 = true,
"mirostat2" => result.mirostat2 = true,
"topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true,
"topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => {
result.incompat_mirostat = true
}
_ => (),
})
.collect::<Vec<_>>();
Expand All @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers {
/// Sample a token. This convenience function handles building
/// the sampler resources and logits objects the sampler needs.
pub fn sample_token(
mut sampler: impl Sampler<TokenId, f32>,
mut sampler: impl Sampler,
rng: &mut impl rand::Rng,
previous_tokens: &[TokenId],
last_logits: impl IntoIterator<Item = f32>,
Expand Down Expand Up @@ -297,7 +315,7 @@ pub fn build_sampler(
n_vocab: usize,
bias: &[(TokenId, f32)],
args: &[impl AsRef<str>],
) -> Result<Arc<Mutex<dyn Sampler<TokenId, f32>>>, SamplerConfigurationError> {
) -> Result<Arc<Mutex<dyn Sampler>>, SamplerConfigurationError> {
let mut samplers = SamplerChain::new();

if !bias.is_empty() {
Expand Down Expand Up @@ -326,7 +344,7 @@ pub fn build_sampler(
}

/// Get the default sampler chain.
pub fn default_samplers() -> Arc<Mutex<dyn Sampler<TokenId, f32>>> {
pub fn default_samplers() -> Arc<Mutex<dyn Sampler>> {
let mut result = ConfiguredSamplers::default();
result.ensure_default_slots();
Arc::new(Mutex::new(result.builder.into_chain()))
Expand All @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> {
}

impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> {
type TokenId = TokenId;

fn with_rng_mut(
&mut self,
fun: &mut dyn FnMut(&mut dyn rand::RngCore),
Expand All @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> {
Ok(())
}

fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> {
fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> {
fun(self.previous_tokens);
Ok(())
}
Expand Down

0 comments on commit 5fa9bb2

Please sign in to comment.