Skip to content

Commit

Permalink
Reimplement vs llm-samplers 0.0.6
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Aug 6, 2023
1 parent ec9052e commit 500f068
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 248 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1", features = ["log"] }
llm-samplers = "=0.0.5"
llm-samplers = "=0.0.6"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
54 changes: 25 additions & 29 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ use std::{
fmt,
ops::Deref,
path::{Path, PathBuf},
str::FromStr,
};

use clap::{ArgAction, Parser, ValueEnum};
use clap::{Parser, ValueEnum};
use color_eyre::eyre::{self, WrapErr};
use llm::{
ggml_format,
samplers::{build_sampler, ConfiguredSampler},
ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress,
Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias, TokenId, TokenizerSource,
ggml_format, samplers::build_sampler, ElementType, InferenceParameters, InferenceSessionConfig,
InvalidTokenBias, LoadProgress, Model, ModelKVMemoryType, ModelParameters, RoPEOverrides,
TokenBias, TokenId, TokenizerSource,
};
use rand::SeedableRng;

Expand Down Expand Up @@ -246,39 +244,39 @@ pub struct Generate {
pub batch_size: usize,

/// Configure sampler settings using a string in the format: sampler_name:key1=value1:key2=value2
/// This option may be specified multiple times.
/// NOTE: If mirostat1 or mirostat2 samplers are configured then samplers other than repetition, frequency/presence and temperature will be ignored.
/// To configure multiple samplers at once, separate the sampler configuration strings with space or '/' (forward slash).
/// NOTE: Mirostat samplers are incompatible with top-p, top-k, locally typical and tail free samplers.
/// TIPS:
/// 1. Sampler options aren't required, but the colon after the name is. For example "mirostat1:" will enable Mirostat 1 with its default options.
/// 1. Sampler options aren't required. For example "mirostat1" will enable Mirostat 1 with its default options.
/// 2. It's possible to specify partial option names, as long as they are unambiguous.
/// 3. You can skip the underscore in sampler (but not option) names.
/// 3. Underscore and dash are ignored in sampler names, so "top-p" is the same as "topp" or "top_p".
///
/// Configurable samplers (defaults shown in parenthesis):
///
/// freq_presence (default: disabled) - Allows penalizing tokens for presence and frequency.
/// freq_presence (default: disabled) - Allows penalizing tokens for presence and frequency. May be specified more than once.
/// frequency_penalty(0.0): Penalty to apply to tokens based on frequency. For example, if a token has appeared 3 times within the last_n range then it will have its probability decreased by 3 * frequency_penalty.
/// presence_penalty(0.0): Penalty to apply to tokens that are already present within the last_n tokens.
/// last_n(64): Number of previous tokens to consider.
///
/// locally_typical (default: disabled) - An approach to sampling that attempts to maximize natural and human-like output. See: https://arxiv.org/abs/2202.00666
/// locally_typical (default: disabled) - An approach to sampling that attempts to maximize natural and human-like output. See: <https://arxiv.org/abs/2202.00666>
/// p(1.0): Referred to as τ in the paper. The paper suggests using 0.2 as a value for story generation and `0.95` for "abstractive summarization" (presumably meaning more factual output). 1.0 appears to be the same as disabled, which is similar to top-p sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// mirostat1 (default: disabled) - See: https://arxiv.org/abs/2007.14966
/// mirostat1 (default: disabled) - See: <https://arxiv.org/abs/2007.14966>
/// eta(0.1): Learning rate
/// tau(5.0): Target entropy
/// mu(tau * 2): Initial learning state value. Setting this is generally not recommended.
///
/// mirostat2 (default: disabled) - See: https://arxiv.org/abs/2007.14966
/// mirostat2 (default: disabled) - See: <https://arxiv.org/abs/2007.14966>
/// eta(0.1): Learning rate
/// tau(5.0): Target entropy
/// mu(tau * 2): Initial learning state value. Setting this is generally not recommended.
///
/// repetition - Allows setting a repetition penalty.
/// repetition - Allows setting a repetition penalty. May be specified more than once.
/// penalty(1.30): The penalty for repeating tokens. Higher values make the generation less likely to get into a loop, but may harm results when repetitive outputs are desired.
/// last_n(64): Number of previous tokens to consider.
///
/// tail_free (default: disabled) - An approach to sampling that attempts to outperform existing nucleus (top-p and top-k) methods. See: https://trentbrick.github.io/Tail-Free-Sampling/
/// tail_free (default: disabled) - An approach to sampling that attempts to outperform existing nucleus (top-p and top-k) methods. See: <https://trentbrick.github.io/Tail-Free-Sampling/>
/// z(1.0): It is not entirely clear what a reasonable value is here, but 1.0 appears to be the same as disabled, which is similar to top-p sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
Expand All @@ -292,15 +290,8 @@ pub struct Generate {
/// top_p - Top tokens are added until their cumulative probability is greater than or equal to P and at least min_keep tokens have been seen.
/// p(0.95): The cumulative probability after which no more tokens are kept for sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
#[arg(
long = "sampler",
short = 's',
verbatim_doc_comment,
action = ArgAction::Append,
value_parser = ConfiguredSampler::from_str
)]
pub configured_samplers: Vec<ConfiguredSampler>,
#[arg(long = "sampler", short = 's', verbatim_doc_comment)]
pub sampler_options: Vec<String>,

/// Specifies the seed to use during sampling. Note that, depending on
/// hardware, the same seed may lead to different results on two separate
Expand Down Expand Up @@ -381,14 +372,19 @@ impl Generate {
}
}

pub fn inference_parameters(&self, eot: TokenId, n_vocab: usize) -> InferenceParameters {
pub fn inference_parameters(
&self,
eot: TokenId,
n_vocab: usize,
) -> eyre::Result<InferenceParameters> {
let mut bias: Vec<(TokenId, f32)> = self.token_bias.clone().unwrap_or_default().into();
if self.ignore_eos {
bias.push((eot, f32::NEG_INFINITY));
}
InferenceParameters {
sampler: build_sampler(n_vocab, &bias, self.configured_samplers.clone()),
}
Ok(InferenceParameters {
sampler: build_sampler(n_vocab, &bias, &self.sampler_options)
.map_err(|e| eyre::eyre!("Invalid sampler configuration: {e}"))?,
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion binaries/llm-cli/src/interactive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn initialize_common_state(
let model = model_load.load(generate.use_gpu)?;
Ok((
generate.inference_session_config(),
generate.inference_parameters(model.eot_token_id(), model.tokenizer().len()),
generate.inference_parameters(model.eot_token_id(), model.tokenizer().len())?,
model,
generate.rng(),
))
Expand Down
2 changes: 1 addition & 1 deletion binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
);
let parameters = args
.generate
.inference_parameters(model.eot_token_id(), model.tokenizer().len());
.inference_parameters(model.eot_token_id(), model.tokenizer().len())?;

let mut rng = args.generate.rng();

Expand Down
4 changes: 1 addition & 3 deletions crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,8 @@ unsafe impl Sync for InferenceParameters {}

impl Default for InferenceParameters {
fn default() -> Self {
let chain: SamplerChain<TokenId, f32> =
crate::samplers::ConfiguredSamplers::default().into();
Self {
sampler: Arc::new(Mutex::new(chain)),
sampler: samplers::default_samplers(),
}
}
}
Loading

0 comments on commit 500f068

Please sign in to comment.