refactor: remove model-associated inference params
philpax committed May 23, 2023
1 parent c51ff5d commit 964b2cd
Showing 14 changed files with 40 additions and 113 deletions.
1 change: 0 additions & 1 deletion binaries/llm-cli/src/cli_args.rs
@@ -367,7 +367,6 @@ impl ModelLoad {
             prefer_mmap: !self.no_mmap,
             context_size: self.num_ctx_tokens,
             lora_adapters: self.lora_paths.clone(),
-            ..Default::default()
         };

         let mut sp = Some(spinoff::Spinner::new(
4 changes: 2 additions & 2 deletions binaries/llm-cli/src/main.rs
@@ -70,7 +70,7 @@ fn infer<M: llm::KnownModel + 'static>(
         &mut rng,
         &llm::InferenceRequest {
             prompt: prompt.as_str().into(),
-            parameters: Some(&inference_params),
+            parameters: &inference_params,
             play_back_previous_tokens: session_loaded,
             maximum_token_count: args.generate.num_predict,
         },
@@ -277,7 +277,7 @@ fn interactive<M: llm::KnownModel + 'static>(
         &mut rng,
         &llm::InferenceRequest {
             prompt: "".into(),
-            parameters: Some(&inference_params),
+            parameters: &inference_params,
             play_back_previous_tokens: session_loaded,
             maximum_token_count: args.generate.num_predict,
         },
8 changes: 3 additions & 5 deletions crates/llm-base/src/inference_session.rs
@@ -175,7 +175,7 @@ impl InferenceSession {
         let mut stats = InferenceStats::default();
         let start_at = std::time::SystemTime::now();

-        let parameters = request.parameters.unwrap_or(model.inference_parameters());
+        let parameters = request.parameters;

         // Feed the initial prompt through the transformer, to update its
         // context window with new data.
@@ -635,15 +635,13 @@ impl Default for InferenceSessionConfig {
     }
 }

-#[derive(Debug, PartialEq, Default, Clone, Copy)]
+#[derive(Debug, PartialEq, Clone, Copy)]
 /// Settings specific to [InferenceSession::infer].
 pub struct InferenceRequest<'a> {
     /// The prompt to feed to the model.
     pub prompt: Prompt<'a>,
     /// The parameters to use during this inference attempt.
-    /// If not specified, this will default to the parameters
-    /// specified in the model.
-    pub parameters: Option<&'a InferenceParameters>,
+    pub parameters: &'a InferenceParameters,
     /// Whether or not to call the callback with the previous tokens
     /// that were encountered in this session.
     ///
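
Since `InferenceRequest` no longer derives `Default` and `parameters` is now a plain reference rather than an `Option`, callers must spell out every field. A minimal sketch of the new call shape as of this commit; the prompt text is illustrative:

let params = llm::InferenceParameters::reasonable_default();
let request = llm::InferenceRequest {
    // A `&str` converts into `Prompt` via `Into`.
    prompt: "Tell me a story.".into(),
    // Mandatory now: there is no model-side fallback to fall through to.
    parameters: &params,
    // Do not replay tokens already evaluated in this session.
    play_back_previous_tokens: false,
    // `None` places no explicit cap on the number of generated tokens.
    maximum_token_count: None,
};
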
14 changes: 11 additions & 3 deletions crates/llm-base/src/lib.rs
@@ -64,16 +64,24 @@ pub struct InferenceParameters {
     /// The number of tokens to consider for the repetition penalty.
     pub repetition_penalty_last_n: usize,
 }
-impl Default for InferenceParameters {
-    fn default() -> Self {
+impl InferenceParameters {
+    /// Returns a reasonable default for the parameters.
+    ///
+    /// Note that these parameters are not necessarily optimal for all models, and that
+    /// you may want to tweak them for your use case.
+    ///
+    /// This is intentionally not a `Default` implementation. The values specified here may change
+    /// in the future, and we want to make sure that users are aware of this and do not accidentally
+    /// rely on the values.
+    pub const fn reasonable_default() -> Self {
         Self {
             n_threads: 8,
             n_batch: 8,
             top_k: 40,
             top_p: 0.95,
             repeat_penalty: 1.30,
             temperature: 0.80,
-            bias_tokens: TokenBias::default(),
+            bias_tokens: TokenBias::empty(),
             repetition_penalty_last_n: 512,
         }
     }
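
Because `reasonable_default` is an ordinary `const fn` rather than a `Default` impl, tweaking a single field goes through struct update syntax against an explicit call, which keeps the opt-in visible at every use site. A short sketch; the temperature override is an arbitrary illustration, not a recommendation:

let params = llm::InferenceParameters {
    // Override just the sampling temperature...
    temperature: 0.7,
    // ...and take every other field from the library's suggested values.
    ..llm::InferenceParameters::reasonable_default()
};
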
17 changes: 0 additions & 17 deletions crates/llm-base/src/model/mod.rs
@@ -163,11 +163,6 @@ pub trait KnownModel: Send + Sync {

     /// Get the end of text/end of string token ID. This value is defined by model implementers.
     fn eot_token_id(&self) -> TokenId;
-
-    /// Get the default [InferenceParameters] for this model (used by
-    /// [InferenceSession::infer]). This value is configured through
-    /// [ModelParameters::inference_parameters].
-    fn inference_parameters(&self) -> &InferenceParameters;
 }

 /// A type-erased model to allow for interacting with a model without knowing
@@ -200,11 +195,6 @@ pub trait Model: Send + Sync {

     /// Get the end of text/end of string token ID. This value is defined by model implementers.
     fn eot_token_id(&self) -> TokenId;
-
-    /// Get the default [InferenceParameters] for this model (used by
-    /// [InferenceSession::infer]). This value is configured through
-    /// [ModelParameters::inference_parameters].
-    fn inference_parameters(&self) -> &InferenceParameters;
 }
 impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
     fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
@@ -236,10 +226,6 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
     fn eot_token_id(&self) -> TokenId {
         KnownModel::eot_token_id(self)
     }
-
-    fn inference_parameters(&self) -> &InferenceParameters {
-        KnownModel::inference_parameters(self)
-    }
 }

 /// Implemented by model hyperparameters for interacting with hyperparameters
@@ -280,8 +266,6 @@ pub struct ModelParameters {
     /// The context size ("memory") the model should use when evaluating a prompt. A larger context
     /// consumes more resources, but produces more consistent and coherent responses.
     pub context_size: usize,
-    /// Default InferenceParameters to use when [evaluating](Model::evaluate) a prompt with this model.
-    pub inference_parameters: InferenceParameters,
     /// The [LoRA](https://arxiv.org/abs/2106.09685) adapters to use when loading the model. If `None`, no adapters will be used.
     pub lora_adapters: Option<Vec<PathBuf>>,
 }
@@ -291,7 +275,6 @@ impl Default for ModelParameters {
         Self {
             prefer_mmap: true,
             context_size: 2048,
-            inference_parameters: Default::default(),
             lora_adapters: None,
         }
     }
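
With the field removed, `ModelParameters` configures loading concerns only, and its `Default` impl survives. A sketch of a post-change construction; the values simply restate the defaults shown above:

let model_params = llm::ModelParameters {
    prefer_mmap: true,
    context_size: 2048,
    lora_adapters: None,
};
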
5 changes: 5 additions & 0 deletions crates/llm-base/src/vocabulary.rs
@@ -215,6 +215,11 @@ impl<'a> From<&'a Vec<TokenId>> for Prompt<'a> {
 pub struct TokenBias(Vec<(TokenId, f32)>);

 impl TokenBias {
+    /// Create an empty [TokenBias].
+    pub const fn empty() -> Self {
+        Self(Vec::new())
+    }
+
     /// Create a [TokenBias] from an existing `Vec`.
     pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
         v.sort_by_cached_key(|(tid, _)| *tid);
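
The new constructor is `const` because `Default::default()` cannot be called in a `const fn`, and `reasonable_default` above needs a compile-time way to build an empty bias list. A sketch contrasting the two constructors, assuming the type is re-exported as `llm::TokenBias`; the token ID and bias value are made up for illustration:

// No biasing at all; usable in const contexts.
let no_bias = llm::TokenBias::empty();

// Discourage a hypothetical token ID 50256 with a strong negative bias.
let bias = llm::TokenBias::new(vec![(50256, -1.0)]);
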
8 changes: 5 additions & 3 deletions crates/llm/examples/inference.rs
@@ -1,6 +1,6 @@
 use llm::{
-    load_progress_callback_stdout as load_callback, InferenceFeedback, InferenceRequest,
-    InferenceResponse, ModelArchitecture,
+    load_progress_callback_stdout as load_callback, InferenceFeedback, InferenceParameters,
+    InferenceRequest, InferenceResponse, ModelArchitecture,
 };
 use std::{convert::Infallible, io::Write, path::Path};

@@ -44,7 +44,9 @@ fn main() {
         &mut rand::thread_rng(),
         &InferenceRequest {
             prompt: prompt.into(),
-            ..Default::default()
+            parameters: &InferenceParameters::reasonable_default(),
+            play_back_previous_tokens: false,
+            maximum_token_count: None,
         },
         // OutputRequest
         &mut Default::default(),
12 changes: 8 additions & 4 deletions crates/llm/examples/vicuna-chat.rs
@@ -1,6 +1,6 @@
 use llm::{
-    InferenceFeedback, InferenceRequest, InferenceResponse, InferenceStats, LoadProgress,
-    ModelArchitecture,
+    InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceStats,
+    LoadProgress, ModelArchitecture,
 };
 use rustyline::error::ReadlineError;
 use spinoff::{spinners::Dots2, Spinner};
@@ -43,10 +43,12 @@ fn main() {
         {character_name}: Paris is the capital of France."
     );

+    let inference_parameters = InferenceParameters::reasonable_default();
+
     session
         .feed_prompt(
             model.as_ref(),
-            &Default::default(),
+            &inference_parameters,
             format!("{persona}\n{history}").as_str(),
             &mut Default::default(),
             llm::feed_prompt_callback(prompt_callback),
@@ -73,7 +75,9 @@ fn main() {
             prompt: format!("{user_name}: {line}\n{character_name}:")
                 .as_str()
                 .into(),
-            ..Default::default()
+            parameters: &inference_parameters,
+            play_back_previous_tokens: false,
+            maximum_token_count: None,
         },
         &mut Default::default(),
         inference_callback(String::from(user_name), &mut buf),
14 changes: 1 addition & 13 deletions crates/models/bloom/src/lib.rs
@@ -35,9 +35,6 @@ pub struct Bloom {
     // weights for the model
     layers: Vec<Layer>,

-    // default parameters used by [InferenceSession::infer]
-    inference_parameters: InferenceParameters,
-
     // must be kept alive for the model
     _context: ggml::Context,
     _mmap: Option<Mmap>,
@@ -95,11 +92,7 @@ impl KnownModel for Bloom {

         let (_context, _, _mmap) = tl.finish();

-        let ModelParameters {
-            context_size,
-            inference_parameters,
-            ..
-        } = params;
+        let ModelParameters { context_size, .. } = params;

         Ok(Bloom {
             hyperparameters,
@@ -112,7 +105,6 @@ impl KnownModel for Bloom {
             out_norm_bias,
             output,
             layers,
-            inference_parameters,
             _context,
             _mmap,
         })
@@ -393,10 +385,6 @@ impl KnownModel for Bloom {
             .copied()
             .unwrap()
     }
-
-    fn inference_parameters(&self) -> &InferenceParameters {
-        &self.inference_parameters
-    }
 }
14 changes: 1 addition & 13 deletions crates/models/gpt2/src/lib.rs
@@ -34,9 +34,6 @@ pub struct Gpt2 {
     // weights for the model
     layers: Vec<Layer>,

-    // default parameters used by [InferenceSession::infer]
-    inference_parameters: InferenceParameters,
-
     // must be kept alive for the model
     _context: ggml::Context,
     _mmap: Option<Mmap>,
@@ -87,11 +84,7 @@ impl KnownModel for Gpt2 {

         let (_context, _, _mmap) = tl.finish();

-        let ModelParameters {
-            context_size,
-            inference_parameters,
-            ..
-        } = params;
+        let ModelParameters { context_size, .. } = params;

         Ok(Gpt2 {
             hyperparameters,
@@ -103,7 +96,6 @@ impl KnownModel for Gpt2 {
             wte,
             wpe,
             lm_head,
-            inference_parameters,
             _context,
             _mmap,
         })
@@ -349,10 +341,6 @@ impl KnownModel for Gpt2 {
             .copied()
             .unwrap()
     }
-
-    fn inference_parameters(&self) -> &InferenceParameters {
-        &self.inference_parameters
-    }
 }
14 changes: 1 addition & 13 deletions crates/models/gptj/src/lib.rs
@@ -35,9 +35,6 @@ pub struct GptJ {
     // weights for the model
     layers: Vec<Layer>,

-    // default parameters used by [InferenceSession::infer]
-    inference_parameters: InferenceParameters,
-
     // must be kept alive for the model
     _context: ggml::Context,
     _mmap: Option<Mmap>,
@@ -89,11 +86,7 @@ impl KnownModel for GptJ {

         let (_context, _, _mmap) = tl.finish();

-        let ModelParameters {
-            context_size,
-            inference_parameters,
-            ..
-        } = params;
+        let ModelParameters { context_size, .. } = params;

         Ok(GptJ {
             hyperparameters,
@@ -105,7 +98,6 @@ impl KnownModel for GptJ {
             lmh_g,
             lmh_b,
             layers,
-            inference_parameters,
             _mmap,
             _context,
         })
@@ -319,10 +311,6 @@ impl KnownModel for GptJ {
             .copied()
             .unwrap()
     }
-
-    fn inference_parameters(&self) -> &InferenceParameters {
-        &self.inference_parameters
-    }
 }
14 changes: 1 addition & 13 deletions crates/models/gptneox/src/lib.rs
@@ -35,9 +35,6 @@ pub struct GptNeoX {
     // weights for the model
     layers: Vec<Layer>,

-    // default parameters used by [InferenceSession::infer]
-    inference_parameters: InferenceParameters,
-
     // must be kept alive for the model
     _context: ggml::Context,
     _mmap: Option<Mmap>,
@@ -103,11 +100,7 @@ impl KnownModel for GptNeoX {

         let (_context, _, _mmap) = tl.finish();

-        let ModelParameters {
-            context_size,
-            inference_parameters,
-            ..
-        } = params;
+        let ModelParameters { context_size, .. } = params;

         Ok(GptNeoX {
             hyperparameters,
@@ -118,7 +111,6 @@ impl KnownModel for GptNeoX {
             wte,
             lmh_g,
             layers,
-            inference_parameters,
             _context,
             _mmap,
         })
@@ -400,10 +392,6 @@ impl KnownModel for GptNeoX {
             .copied()
             .unwrap()
     }
-
-    fn inference_parameters(&self) -> &InferenceParameters {
-        &self.inference_parameters
-    }
 }
(The remaining 2 changed files are not shown.)
