Skip to content

Commit

Permalink
refactor: remove model overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed May 29, 2023
1 parent d795bc8 commit 8f62b39
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 197 deletions.
6 changes: 1 addition & 5 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,7 @@ pub struct ModelLoad {
pub lora_paths: Option<Vec<PathBuf>>,
}
impl ModelLoad {
pub fn load<M: llm::KnownModel + 'static>(
&self,
overrides: Option<M::Overrides>,
) -> Result<Box<dyn Model>> {
pub fn load<M: llm::KnownModel + 'static>(&self) -> Result<Box<dyn Model>> {
let params = ModelParameters {
prefer_mmap: !self.no_mmap,
context_size: self.num_ctx_tokens,
Expand Down Expand Up @@ -435,7 +432,6 @@ impl ModelLoad {
&self.model_and_vocabulary.model_path,
vocabulary_source,
params,
overrides,
|progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
Expand Down
44 changes: 17 additions & 27 deletions binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,31 @@ fn main() -> Result<()> {

let cli_args = Args::parse();
match &cli_args {
Args::Llama { args } => handle_args::<llm::models::Llama>(args, None),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args, None),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args, None),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args, None),
Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args, None),
Args::Mpt { args } => handle_args::<llm::models::Mpt>(args, None),
Args::Llama { args } => handle_args::<llm::models::Llama>(args),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
}
}

fn handle_args<M: llm::KnownModel + 'static>(
args: &cli_args::BaseArgs,
overrides: Option<M::Overrides>,
) -> Result<()> {
fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
match args {
BaseArgs::Infer(args) => infer::<M>(args, overrides),
BaseArgs::Perplexity(args) => perplexity::<M>(args, overrides),
BaseArgs::Infer(args) => infer::<M>(args),
BaseArgs::Perplexity(args) => perplexity::<M>(args),
BaseArgs::Info(args) => info::<M>(args),
BaseArgs::PromptTokens(args) => prompt_tokens::<M>(args),
BaseArgs::Repl(args) => interactive::<M>(args, overrides, false),
BaseArgs::Chat(args) => interactive::<M>(args, overrides, true),
BaseArgs::Repl(args) => interactive::<M>(args, false),
BaseArgs::Chat(args) => interactive::<M>(args, true),
BaseArgs::Quantize(args) => quantize::<M>(args),
}
}

fn infer<M: llm::KnownModel + 'static>(
args: &cli_args::Infer,
overrides: Option<M::Overrides>,
) -> Result<()> {
fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>(overrides)?;
let model = args.model_load.load::<M>()?;

let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
Expand Down Expand Up @@ -120,13 +114,10 @@ fn infer<M: llm::KnownModel + 'static>(
Ok(())
}

fn perplexity<M: llm::KnownModel + 'static>(
args: &cli_args::Perplexity,
overrides: Option<M::Overrides>,
) -> Result<()> {
fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>(overrides)?;
let model = args.model_load.load::<M>()?;
let (mut session, _) = snapshot::read_or_create_session(
model.as_ref(),
None,
Expand Down Expand Up @@ -191,7 +182,7 @@ fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {

fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let model = args.model_load.load::<M>(None)?;
let model = args.model_load.load::<M>()?;
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
Expand Down Expand Up @@ -220,14 +211,13 @@ fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) ->

fn interactive<M: llm::KnownModel + 'static>(
args: &cli_args::Repl,
overrides: Option<M::Overrides>,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
chat_mode: bool,
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>(overrides)?;
let model = args.model_load.load::<M>()?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
Expand Down
5 changes: 1 addition & 4 deletions crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ pub use loader::{
};
pub use lora::{LoraAdapter, LoraParameters};
pub use memmap2::Mmap;
pub use model::{
Hyperparameters, KnownModel, Model, ModelDynamicOverrideValue, ModelDynamicOverrides,
ModelParameters, OutputRequest,
};
pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
pub use quantize::{quantize, QuantizeError, QuantizeProgress};
pub use regex::Regex;
pub use util::TokenUtf8Buffer;
Expand Down
3 changes: 1 addition & 2 deletions crates/llm-base/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ pub fn load<M: KnownModel>(
path: &Path,
vocabulary_source: VocabularySource,
params: ModelParameters,
overrides: Option<M::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
if !path.exists() {
Expand Down Expand Up @@ -492,7 +491,7 @@ pub fn load<M: KnownModel>(
loaded_tensors: Default::default(),
};

let model = KnownModel::new(hyperparameters, params, overrides, vocabulary, tl)?;
let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;

(load_progress_callback)(LoadProgress::Loaded {
file_size,
Expand Down
96 changes: 1 addition & 95 deletions crates/llm-base/src/model/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
//! Large language model traits and types
use std::{
collections::HashMap,
error::Error,
fmt::Debug,
io::{BufRead, Write},
path::{Path, PathBuf},
};

use regex::Regex;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;

use crate::{
Expand All @@ -20,124 +18,32 @@ use crate::{
/// Common functions for model evaluation
pub mod common;

// Generates the `ModelDynamicOverrideValue` enum together with the conversions
// between it and each wrapped type. Every `(Variant, Type, "doc")` entry
// contributes one tuple variant carrying that type, a `From<Type>` impl that
// wraps a value in the variant, and a `TryFrom<ModelDynamicOverrideValue>` impl
// that unwraps it again (failing with `()` when another variant is stored).
macro_rules! define_model_dynamic_override_value {
    ($(($variant:ident, $inner:ty, $summary:literal)),*) => {
        #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
        #[serde(untagged)]
        /// Valid value types for dynamic model overrides.
        pub enum ModelDynamicOverrideValue {
            $(#[doc=$summary] $variant($inner),)*
        }

        $(
            impl From<$inner> for ModelDynamicOverrideValue {
                fn from(raw: $inner) -> Self {
                    Self::$variant(raw)
                }
            }

            impl TryFrom<ModelDynamicOverrideValue> for $inner {
                type Error = ();

                fn try_from(wrapped: ModelDynamicOverrideValue) -> Result<Self, Self::Error> {
                    if let ModelDynamicOverrideValue::$variant(raw) = wrapped {
                        Ok(raw)
                    } else {
                        Err(())
                    }
                }
            }
        )*
    };
}

// Instantiates `ModelDynamicOverrideValue` with four variants: `Bool(bool)`,
// `String(String)`, `Int(i64)` and `Float(f64)`; the string in each entry
// becomes that variant's doc comment.
// NOTE(review): the generated enum is `#[serde(untagged)]`, and serde tries
// untagged variants in declaration order, so the order of these entries
// presumably matters for deserialization (`Int` is attempted before `Float`)
// -- confirm before reordering.
define_model_dynamic_override_value!(
(Bool, bool, "A boolean value"),
(String, String, "A string value"),
(Int, i64, "An integer value"),
(Float, f64, "A float value")
);

/// Model options that can be overridden by the user at runtime.
///
/// Each model has its own set of options that can be overridden.
/// However, the calling code may not know the type of the model
/// at compile time. This type is used to store the overrides
/// for a model in a generic way.
///
/// The wrapped map associates each override's name with its
/// [`ModelDynamicOverrideValue`]. Because the type is marked
/// `#[serde(transparent)]`, it is serialized and deserialized as the
/// bare map rather than as a wrapper around it.
#[derive(Debug, PartialEq, Serialize, Deserialize, Default, Clone)]
#[serde(transparent)]
pub struct ModelDynamicOverrides(pub HashMap<String, ModelDynamicOverrideValue>);
impl ModelDynamicOverrides {
    /// Get the value of the override with the given `key`, converted to `T`.
    ///
    /// Returns `None` when there is no override stored under `key`, and also
    /// when the stored value cannot be converted to `T`; the two cases are
    /// not distinguishable by the caller. The stored value is cloned before
    /// conversion, so the map itself is left unchanged.
    pub fn get<T: TryFrom<ModelDynamicOverrideValue>>(&self, key: &str) -> Option<T> {
        self.0
            .get(key)
            .cloned()
            .and_then(|value| T::try_from(value).ok())
    }

    /// Merge the overrides from `other` into this one.
    ///
    /// Entries from `other` replace existing entries that share the same key.
    /// Returns `self` so that further calls can be chained.
    pub fn merge(&mut self, other: impl Into<Self>) -> &mut Self {
        // `extend` accepts any `IntoIterator`, so the map is passed directly
        // instead of being converted with an explicit `.into_iter()`.
        self.0.extend(other.into().0);
        self
    }

    /// Insert a new override with the given `key` and `value`.
    ///
    /// Any value previously stored under `key` is replaced.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<ModelDynamicOverrideValue>) {
        self.0.insert(key.into(), value.into());
    }
}
/// Converting a set of dynamic overrides into `()` simply discards them.
impl From<ModelDynamicOverrides> for () {
    fn from(_discarded: ModelDynamicOverrides) -> Self {}
}
impl From<()> for ModelDynamicOverrides {
fn from(_: ()) -> Self {
Self::default()
}
}

/// Interfaces for creating and interacting with a large language model with a known type
/// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)).
pub trait KnownModel: Send + Sync {
/// Hyperparameters for the model.
type Hyperparameters: Hyperparameters;

/// Model options that can be overridden by the user.
///
/// If there are no options to override, use `()`.
type Overrides: Serialize
+ DeserializeOwned
+ Default
+ From<ModelDynamicOverrides>
+ Into<ModelDynamicOverrides>;

/// Load this model from the `path` and configure it per the `params`. The status
/// of the loading process will be reported through `load_progress_callback`. This
/// is a helper function on top of [llm_base::load](crate::load).
fn load(
path: &Path,
vocabulary_source: VocabularySource,
params: ModelParameters,
overrides: Option<Self::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<Self, LoadError>
where
Self: Sized,
{
crate::load(
path,
vocabulary_source,
params,
overrides,
load_progress_callback,
)
crate::load(path, vocabulary_source, params, load_progress_callback)
}

/// Creates a new model from the provided hyperparameters and [ModelParameters].
/// This function is called by the [load](crate::loader::load) function.
fn new<E: Error>(
hyperparameters: Self::Hyperparameters,
params: ModelParameters,
overrides: Option<Self::Overrides>,
vocabulary: Vocabulary,
tensor_loader: impl TensorLoader<E>,
) -> Result<Self, E>
Expand Down
1 change: 0 additions & 1 deletion crates/llm/examples/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ fn main() {
&path,
vocabulary_source,
model_params,
None,
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
Expand Down
1 change: 0 additions & 1 deletion crates/llm/examples/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ fn main() {
&path,
vocabulary_source,
Default::default(),
None,
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
Expand Down
1 change: 0 additions & 1 deletion crates/llm/examples/vicuna-chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ fn main() {
&path,
vocabulary_source,
Default::default(),
None,
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
Expand Down
Loading

0 comments on commit 8f62b39

Please sign in to comment.