Skip to content

Commit

Permalink
pub use the llm-samplers crate for convenience.
Browse files Browse the repository at this point in the history
Expose the sampler configuration structures to allow more flexibility.

Add more documentation and description for the sampling functions and structures.

Create specific enums for sampler construction and sampling errors.

Set n_vocab for the Mirostat 1 sampler in a more reliable way.
  • Loading branch information
KerfuffleV2 authored and AmineDiro committed Aug 15, 2023
1 parent 5ac346c commit 6424e4c
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 27 deletions.
2 changes: 1 addition & 1 deletion crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ pub enum InferenceError {
UserCallback(Box<dyn std::error::Error + Send + Sync>),
/// Sampling returned an error.
#[error("token sampling failed")]
SamplerFailure(Box<dyn std::error::Error + Send + Sync>),
SamplerFailure(crate::samplers::SamplingError),
}

#[derive(Error, Debug)]
Expand Down
158 changes: 132 additions & 26 deletions crates/llm-base/src/samplers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Types and methods used for constructing and running
//! the samplers used for generation.
//!
//! The `llm-samplers` crate is also re-exported here for convenient use as `llm_samplers`.
use std::{
error::Error,
Expand All @@ -8,18 +10,88 @@ use std::{
sync::{Arc, Mutex},
};

use thiserror::Error;

pub use llm_samplers;

use llm_samplers::{configure::*, prelude::*};

use crate::TokenId;

#[derive(Debug, Error)]
/// Errors related to constructing samplers from string definitions.
pub enum SamplerConfigurationError {
#[error("An incompatible combination of samplers was requested: {0}")]
/// Not all combinations of samplers are valid. This error will be returned
/// when an invalid combination is specified.
SamplerCombinationError(String),

#[error("Error configuring sampler {name}: {err}")]
/// The sampler name was unknown or the options to it were invalid.
BuildSamplerError {
/// Name of the sampler that failed.
name: String,
/// The actual error.
err: Box<dyn Error + Send + Sync + 'static>,
},
}

#[derive(Debug, Error)]
/// Errors that occured during sampling.
pub enum SamplingError {
#[error("Sampling failed to produce a token")]
/// Sampling didn't produce a token.
NoToken,

#[error("An error occured constructing logits for sampling: {0}")]
/// Constructing logits failed. This can usually only happen if a logit is NaN.
LogitsError(Box<dyn Error + Send + Sync + 'static>),

#[error("An internal error occured during sampling: {0}")]
/// Sampling failed.
InternalSamplingError(Box<dyn Error + Send + Sync + 'static>),
}

#[derive(Debug)]
struct ConfiguredSamplers {
builder: SamplerChainBuilder,
mirostat1: bool,
mirostat2: bool,
incompat_mirostat: bool,
/// Used for configuring samplers dynamically from string definitions.
/// For example, commandline arguments. Constructing this structure manually is
/// not recommended. Use the [build_sampler] function or the [FromStr] instance
/// to ensure a valid configuration.
pub struct ConfiguredSamplers {
/// A builder from the `llm-samplers` crate.
pub builder: SamplerChainBuilder,
/// Mirostat 1 is present.
pub mirostat1: bool,
/// Mirostat 2 is present.
pub mirostat2: bool,
/// Samplers incompatible with Mirostat 1 and 2 are present.
pub incompat_mirostat: bool,
}

/// Construct a default instance of the structure. The `builder`
/// field contains a list of slots that may be optional.
///
/// We call a configuration of samplers that run in a certain order a "chain".
/// Here is a description of the default chain `llm` uses:
///
/// 1. Repetition (present by default, multiple allowed)
/// 2. Frequency/Presence (optional, multiple allowed)
/// 3. Sequence Repetition (optional, multiple allowed)
/// 4. Top-K (present by default - incompatible with Mirostat)
/// 5. Tail Free (optional - incompatible with Mirostat)
/// 6. Locally Typical (optional - incompatible with Mirostat)
/// 7. Top-P (present by default - incompatible with Mirostat)
/// 8. Temperature (present by default)
/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution.
///
/// Samplers listed as "present by default" but incompatible with Mirostat will
/// only be enabled by default if there is no Mirostat sampler enabled.
///
/// It's worth mentioning that "present by default" samplers that allow multiple instances
/// will add at least one entry if the user didn't specify the sampler. If they _did_ specify
/// it then no extra "default" sampler of that type will be added. So, for example,
/// if you wanted both the default Repetition sampler _and_ one with custom options, you'd
/// need to configure the Repetition sampler twice.
impl Default for ConfiguredSamplers {
fn default() -> Self {
Self {
Expand Down Expand Up @@ -100,6 +172,9 @@ impl Default for ConfiguredSamplers {
}

impl ConfiguredSamplers {
/// Ensures the default slots are populated after processing options.
/// Currently this is: temperature and repetition samplers
/// Then if neither Mirostat 1 or 2 are enabled: top-p and top-k.
pub fn ensure_default_slots(&mut self) {
self.builder.iter_mut().for_each(|(name, slot)| {
let mirostat = self.mirostat1 || self.mirostat2;
Expand All @@ -118,22 +193,34 @@ impl ConfiguredSamplers {
}
}

pub fn ensure_valid(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
/// Ensure that the configured samplers are compatible with each other.
/// For example, if Mirostat 1 and Mirostat 2 are enabled, this would
/// be invalid.
pub fn ensure_valid(&self) -> Result<(), SamplerConfigurationError> {
if self.mirostat1 && self.mirostat2 {
Err(Box::<dyn Error + Send + Sync>::from(
"Cannot enable both Mirostat 1 and Mirostat 2 samplers",
Err(SamplerConfigurationError::SamplerCombinationError(
"Cannot enable both Mirostat 1 and Mirostat 2 samplers".to_string(),
))?
} else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat {
Err(Box::<dyn Error + Send + Sync>::from(
"Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2",
Err(SamplerConfigurationError::SamplerCombinationError(
"Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(),
))?
}
Ok(())
}
}

/// The structure is generally build from a string definition.
/// Configuring as individual sampler takes the form `sampler_name:key1=value1:key2=value2`.
/// Underscore and dash are ignored when comparing sampler names and comparison is
/// case-insensitive. A partial key name may be specified as long as it's not ambiguous.
/// If the sampler only has one option (for example Temperature) the key and equals sign can
/// be left out entirely.
///
/// Separate multiple sampler configuration strings with space or forward slash.
/// Blank entries are allowed.
impl FromStr for ConfiguredSamplers {
type Err = Box<dyn Error + Send + Sync + 'static>;
type Err = SamplerConfigurationError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut result = Self::default();
Expand Down Expand Up @@ -163,8 +250,14 @@ impl FromStr for ConfiguredSamplers {
})
.collect::<Vec<_>>();

opts.into_iter()
.try_for_each(|(name, args)| result.builder.configure(name, args))?;
opts.into_iter().try_for_each(|(name, args)| {
result.builder.configure(&name, args).map_err(|err| {
SamplerConfigurationError::BuildSamplerError {
name: name.to_string(),
err: err.into(),
}
})
})?;

result.ensure_default_slots();
result.ensure_valid()?;
Expand All @@ -180,42 +273,55 @@ pub fn sample_token(
rng: &mut impl rand::Rng,
previous_tokens: &[TokenId],
last_logits: impl IntoIterator<Item = f32>,
) -> Result<TokenId, Box<dyn Error + Send + Sync>> {
Logits::try_from_iter(last_logits.into_iter())?
) -> Result<TokenId, SamplingError> {
Logits::try_from_iter(last_logits.into_iter())
.map_err(|err| SamplingError::LogitsError(err.into()))?
.sample_token(
&mut SamplerResources {
previous_tokens,
rng,
},
&mut sampler,
)?
.ok_or_else(|| Box::from("sampler did not return a token"))
)
.map_err(|err| SamplingError::InternalSamplingError(err.into()))?
.ok_or_else(|| SamplingError::NoToken)
}

/// Build a sampler with the supplied options, vocab size and token bias list.
#[allow(clippy::type_complexity)]
/// Build a sampler object with the supplied options, vocab size and token bias list.
///
/// Note that this is just a convenience function for building a sampler from
/// string definitions such as commandline arguments. The only limit on constructing
/// your own samplers is your sampler or samplers must implement the [Sampler] trait
/// from the `llm-samplers` crate.
pub fn build_sampler(
n_vocab: usize,
bias: &[(TokenId, f32)],
args: &[impl AsRef<str>],
) -> Result<Arc<Mutex<dyn Sampler<TokenId, f32>>>, Box<dyn std::error::Error + Send + Sync>> {
) -> Result<Arc<Mutex<dyn Sampler<TokenId, f32>>>, SamplerConfigurationError> {
let mut samplers = SamplerChain::new();

if !bias.is_empty() {
samplers += SampleFlatBias::new(bias.iter().copied());
}

let mut sampler_options = args
let sampler_options = args
.iter()
.map(|s| s.as_ref().trim())
.filter(|s| !s.is_empty())
.map(|s| "/".to_string() + s)
.collect::<String>();
if sampler_options.contains("/mirostat1") {
sampler_options += &format!("/mirostat1:n_vocab={n_vocab}");

let mut configured_samplers = ConfiguredSamplers::from_str(&sampler_options)?;
if configured_samplers.mirostat1 {
configured_samplers
.builder
.configure("mirostat1", format!("n_vocab={n_vocab}"))
.map_err(|err| SamplerConfigurationError::BuildSamplerError {
name: "mirostat1".to_string(),
err: err.into(),
})?;
}
let configured_samplers = ConfiguredSamplers::from_str(&sampler_options)?.builder;
samplers += configured_samplers.into_chain();
samplers += configured_samplers.builder.into_chain();
Ok(Arc::new(Mutex::new(samplers)))
}

Expand All @@ -226,7 +332,7 @@ pub fn default_samplers() -> Arc<Mutex<dyn Sampler<TokenId, f32>>> {
Arc::new(Mutex::new(result.builder.into_chain()))
}

// Struct used to temporarily hold resources for the `llm_samplers`
// Structure used to temporarily hold resources for the `llm-samplers`
// sampler.
struct SamplerResources<'pt, 'r> {
previous_tokens: &'pt [TokenId],
Expand Down

0 comments on commit 6424e4c

Please sign in to comment.