refactor: implement ModelOverrides + RedPajama alias
philpax committed May 14, 2023
1 parent 6de98e3 commit 08bc9f9
Showing 19 changed files with 376 additions and 120 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -89,7 +89,7 @@
"kind": "example"
}
},
"args": ["redpajama", "${env:HOME}/.ggml-models/redpajama-incite-7b.bin"],
"args": ["gptneox-redpajama", "${env:HOME}/.ggml-models/redpajama-incite-7b.bin"],
"cwd": "${workspaceFolder}"
},
{
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -22,6 +22,7 @@ log = "0.4"
rand = "0.8.5"
rustyline = { version = "11.0.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
thiserror = "1.0"

112 changes: 68 additions & 44 deletions binaries/llm-cli/src/cli_args.rs
@@ -37,6 +37,21 @@ pub enum Args {
NeoX {
#[command(subcommand)]
args: BaseArgs,

#[arg(long)]
/// By default, the GPT-NeoX architecture uses a parallel residual.
///
/// This flag disables that, as some models out there are trained without it,
/// and the model format does not store this information.
no_parallel_residual: bool,
},
/// Use a GPT-NeoX model with RedPajama's modifications
///
/// (GPT-NeoX with `use_parallel_residual` set to false)
#[clap(id = "gptneox-redpajama")]
GptNeoXRedPajama {
#[command(subcommand)]
args: BaseArgs,
},
}
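
Taken together with the wiring in `main.rs` below, both the new `--no-parallel-residual` flag on the `neox` subcommand and the `gptneox-redpajama` alias resolve to the same `NeoXOverrides` value. A condensed sketch of that mapping (the helper function is illustrative only; the types are the ones introduced in this commit):

// Illustrative only: how the CLI choice maps onto the new overrides type.
fn neox_overrides(no_parallel_residual: bool) -> llm::models::NeoXOverrides {
    llm::models::NeoXOverrides {
        // RedPajama-style models are trained without the parallel residual,
        // so the alias is equivalent to passing --no-parallel-residual.
        use_parallel_residual: !no_parallel_residual,
    }
}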

@@ -312,7 +327,10 @@ pub struct ModelLoad {
pub no_mmap: bool,
}
impl ModelLoad {
pub fn load<M: llm::KnownModel + 'static>(&self) -> Result<Box<dyn Model>> {
pub fn load<M: llm::KnownModel + 'static>(
&self,
overrides: Option<M::Overrides>,
) -> Result<Box<dyn Model>> {
let params = ModelParameters {
prefer_mmap: !self.no_mmap,
n_context_tokens: self.num_ctx_tokens,
@@ -327,49 +345,55 @@ impl ModelLoad {
let now = std::time::Instant::now();
let mut prev_load_time = now;

let model = llm::load::<M>(&self.model_path, params, move |progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
sp.update_text("Loaded hyperparameters")
};
}
LoadProgress::ContextSize { bytes } => log::debug!(
"ggml ctx size = {}",
bytesize::to_string(bytes as u64, false)
),
LoadProgress::TensorLoaded {
current_tensor,
tensor_count,
..
} => {
if prev_load_time.elapsed().as_millis() > 500 {
// We don't want to re-render this on every message, as that causes the
// spinner to constantly reset and not look like it's spinning (and
// it's obviously wasteful).
if let Some(sp) = sp.as_mut() {
sp.update_text(format!(
"Loaded tensor {}/{}",
current_tensor + 1,
tensor_count
));
};
prev_load_time = std::time::Instant::now();
}
}
LoadProgress::Loaded {
file_size,
tensor_count,
} => {
if let Some(sp) = sp.take() {
sp.success(&format!(
"Loaded {tensor_count} tensors ({}) after {}ms",
bytesize::to_string(file_size, false),
now.elapsed().as_millis()
));
};
}
})
.wrap_err("Could not load model")?;
let model =
llm::load::<M>(
&self.model_path,
params,
overrides,
move |progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
sp.update_text("Loaded hyperparameters")
};
}
LoadProgress::ContextSize { bytes } => log::debug!(
"ggml ctx size = {}",
bytesize::to_string(bytes as u64, false)
),
LoadProgress::TensorLoaded {
current_tensor,
tensor_count,
..
} => {
if prev_load_time.elapsed().as_millis() > 500 {
// We don't want to re-render this on every message, as that causes the
// spinner to constantly reset and not look like it's spinning (and
// it's obviously wasteful).
if let Some(sp) = sp.as_mut() {
sp.update_text(format!(
"Loaded tensor {}/{}",
current_tensor + 1,
tensor_count
));
};
prev_load_time = std::time::Instant::now();
}
}
LoadProgress::Loaded {
file_size,
tensor_count,
} => {
if let Some(sp) = sp.take() {
sp.success(&format!(
"Loaded {tensor_count} tensors ({}) after {}ms",
bytesize::to_string(file_size, false),
now.elapsed().as_millis()
));
};
}
},
)
.wrap_err("Could not load model")?;

Ok(Box::new(model))
}
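
Call sites thread the overrides through in the same way; a minimal sketch of the pattern used by the CLI entry points later in this commit (assuming `args` carries the usual `model_load` field):

// Sketch: load a GPT-NeoX model with RedPajama-style overrides via ModelLoad::load.
let model = args.model_load.load::<llm::models::NeoX>(Some(llm::models::NeoXOverrides {
    use_parallel_residual: false,
}))?;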
47 changes: 34 additions & 13 deletions binaries/llm-cli/src/main.rs
@@ -25,29 +25,49 @@ fn main() -> Result<()> {

let cli_args = Args::parse();
match &cli_args {
Args::Llama { args } => handle_args::<llm::models::Llama>(args),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
Args::NeoX { args } => handle_args::<llm::models::NeoX>(args),
Args::Llama { args } => handle_args::<llm::models::Llama>(args, None),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args, None),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args, None),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args, None),
Args::NeoX {
args,
no_parallel_residual,
} => handle_args::<llm::models::NeoX>(
args,
Some(llm::models::NeoXOverrides {
use_parallel_residual: !*no_parallel_residual,
}),
),
Args::GptNeoXRedPajama { args } => handle_args::<llm::models::NeoX>(
args,
Some(llm::models::NeoXOverrides {
use_parallel_residual: false,
}),
),
}
}

fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
fn handle_args<M: llm::KnownModel + 'static>(
args: &cli_args::BaseArgs,
overrides: Option<M::Overrides>,
) -> Result<()> {
match args {
BaseArgs::Infer(args) => infer::<M>(args),
BaseArgs::Infer(args) => infer::<M>(args, overrides),
BaseArgs::Info(args) => info::<M>(args),
BaseArgs::PromptTokens(args) => prompt_tokens::<M>(args),
BaseArgs::Repl(args) => interactive::<M>(args, false),
BaseArgs::Chat(args) => interactive::<M>(args, true),
BaseArgs::Repl(args) => interactive::<M>(args, overrides, false),
BaseArgs::Chat(args) => interactive::<M>(args, overrides, true),
BaseArgs::Quantize(args) => quantize::<M>(args),
}
}

fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
fn infer<M: llm::KnownModel + 'static>(
args: &cli_args::Infer,
overrides: Option<M::Overrides>,
) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(overrides)?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
args.persist_session.as_deref(),
@@ -144,7 +164,7 @@ fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {

fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(None)?;
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
@@ -173,13 +193,14 @@ fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) ->

fn interactive<M: llm::KnownModel + 'static>(
args: &cli_args::Repl,
overrides: Option<M::Overrides>,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
chat_mode: bool,
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(overrides)?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
5 changes: 4 additions & 1 deletion crates/llm-base/src/lib.rs
@@ -29,7 +29,10 @@ pub use loader::{
TensorLoader,
};
pub use memmap2::Mmap;
pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
pub use model::{
Hyperparameters, KnownModel, Model, ModelDynamicOverrideValue, ModelDynamicOverrides,
ModelParameters, OutputRequest,
};
pub use quantize::{quantize, QuantizeError, QuantizeProgress};
pub use util::TokenUtf8Buffer;
pub use vocabulary::{InvalidTokenBias, TokenBias, TokenId, Vocabulary};
3 changes: 2 additions & 1 deletion crates/llm-base/src/loader.rs
@@ -301,6 +301,7 @@ pub trait TensorLoader<E: std::error::Error> {
pub fn load<M: KnownModel>(
path: &Path,
params: ModelParameters,
overrides: Option<M::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
let paths = util::find_all_model_files(path)?;
@@ -447,7 +448,7 @@ pub fn load<M: KnownModel>(
loaded_tensors: Default::default(),
};

let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;
let model = KnownModel::new(hyperparameters, params, overrides, vocabulary, tl)?;

(load_progress_callback)(LoadProgress::Loaded {
file_size,
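
For completeness, the new `overrides` parameter sits between `params` and the progress callback when calling `llm::load` directly. A minimal sketch, assuming the `llm` facade re-exports `load`, `ModelParameters`, and `LoadError` as before and that `ModelParameters` implements `Default` (the model path is a placeholder):

use std::path::Path;

fn load_redpajama() -> Result<llm::models::NeoX, llm::LoadError> {
    llm::load::<llm::models::NeoX>(
        Path::new("/path/to/redpajama-incite-7b.bin"), // placeholder path
        llm::ModelParameters::default(),               // assumed to implement Default
        Some(llm::models::NeoXOverrides {
            use_parallel_residual: false,              // RedPajama models omit the parallel residual
        }),
        |_progress| {},                                // no-op load-progress callback
    )
}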