diff --git a/.vscode/launch.json b/.vscode/launch.json
index 34892f3b..e4921893 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -32,6 +32,24 @@
             "args": ["${env:HOME}/.ggml-models/cerebras-gpt-13b.bin"],
             "cwd": "${workspaceFolder}"
         },
+        {
+            "type": "lldb",
+            "request": "launch",
+            "name": "Debug GPT-J Inference",
+            "cargo": {
+                "args": [
+                    "build",
+                    "--example=gptj-inference",
+                    "--package=llm-gptj"
+                ],
+                "filter": {
+                    "name": "gptj-inference",
+                    "kind": "example"
+                }
+            },
+            "args": ["${env:HOME}/.ggml-models/gpt-j-6b.bin"],
+            "cwd": "${workspaceFolder}"
+        },
         {
             "type": "lldb",
             "request": "launch",
@@ -57,7 +75,7 @@
                     "kind": "example"
                 }
             },
-            "args": ["${env:HOME}/.ggml-models/stablelm-base-alpha-3b-f16.bin"],
+            "args": ["${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"],
             "cwd": "${workspaceFolder}"
         }
     ]
diff --git a/Cargo.lock b/Cargo.lock
index 89a3dbde..498c8baa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -607,6 +607,7 @@ dependencies = [
  "llm-base",
  "llm-bloom",
  "llm-gpt2",
+ "llm-gptj",
  "llm-llama",
  "llm-neox",
 ]
@@ -663,6 +664,16 @@ dependencies = [
  "rand",
 ]
 
+[[package]]
+name = "llm-gptj"
+version = "0.1.0"
+dependencies = [
+ "bytemuck",
+ "ggml",
+ "llm-base",
+ "rand",
+]
+
 [[package]]
 name = "llm-llama"
 version = "0.1.0"
diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index 6ba24a0a..bd19f38f 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -26,6 +26,12 @@ pub enum Args {
         #[command(subcommand)]
         args: BaseArgs,
     },
+    /// Use a GPT-J model
+    #[clap(id = "gptj")]
+    GptJ {
+        #[command(subcommand)]
+        args: BaseArgs,
+    },
     /// Use a GPT-NeoX model
     #[clap(id = "neox")]
     NeoX {
diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index e630dc32..1f70e4f6 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -25,6 +25,7 @@ fn main() -> Result<()> {
         Args::Llama { args } => handle_args::<llm::models::Llama>(args),
         Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
         Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
+        Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
         Args::NeoX { args } => handle_args::<llm::models::NeoX>(args),
     }
 }
diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml
index 8268344b..4dd9c338 100644
--- a/crates/llm/Cargo.toml
+++ b/crates/llm/Cargo.toml
@@ -7,12 +7,14 @@ edition = "2021"
 llm-base = { path = "../llm-base" }
 llm-llama = { path = "../models/llama", features = ["convert"], optional = true }
 llm-gpt2 = { path = "../models/gpt2", optional = true }
+llm-gptj = { path = "../models/gptj", optional = true }
 llm-bloom = { path = "../models/bloom", optional = true }
 llm-neox = { path = "../models/neox", optional = true }
 
 [features]
-default = ["llama", "gpt2", "bloom", "neox"]
+default = ["llama", "gpt2", "gptj", "bloom", "neox"]
 llama = ["dep:llm-llama"]
 gpt2 = ["dep:llm-gpt2"]
+gptj = ["dep:llm-gptj"]
 bloom = ["dep:llm-bloom"]
 neox = ["dep:llm-neox"]
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index e5a7d849..3ea46ee4 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -20,6 +20,8 @@ pub mod models {
     pub use llm_bloom::{self as bloom, Bloom};
     #[cfg(feature = "gpt2")]
     pub use llm_gpt2::{self as gpt2, Gpt2};
+    #[cfg(feature = "gptj")]
+    pub use llm_gptj::{self as gptj, GptJ};
     #[cfg(feature = "llama")]
     pub use llm_llama::{self as llama, Llama};
     #[cfg(feature = "neox")]
diff --git a/crates/models/gptj/Cargo.toml b/crates/models/gptj/Cargo.toml
new file mode 100644
index 00000000..c94ff855
--- /dev/null
+++ b/crates/models/gptj/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "llm-gptj"
+version = { workspace = true }
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+llm-base = { path = "../../llm-base" }
+ggml = { path = "../../ggml" }
+
+bytemuck = { workspace = true }
+
+[dev-dependencies]
+rand = { workspace = true }
diff --git a/crates/models/gptj/examples/gptj-inference.rs b/crates/models/gptj/examples/gptj-inference.rs
new file mode 100644
index 00000000..5476ec03
--- /dev/null
+++ b/crates/models/gptj/examples/gptj-inference.rs
@@ -0,0 +1,40 @@
+use std::{convert::Infallible, env::args, io::Write, path::Path};
+
+use llm_base::{load_progress_callback_stdout, KnownModel};
+
+fn main() {
+    let args: Vec<String> = args().collect();
+    let loc = &args[1];
+    let prompt = match &args.len() {
+        3 => &args[2],
+        _ => "Rust is a cool programming language because ",
+    };
+
+    println!(" >>> Loading model from {loc}...");
+    let now = std::time::Instant::now();
+
+    let gptj = llm_gptj::GptJ::load(Path::new(loc), true, 512, load_progress_callback_stdout)
+        .unwrap_or_else(|e| panic!("Error loading model from {loc}: {e}"));
+
+    println!(" >>> Model loaded in {} ms.", now.elapsed().as_millis());
+
+    let mut session = gptj.start_session(Default::default());
+    let res = session.inference_with_prompt::<Infallible>(
+        &gptj,
+        &Default::default(),
+        &Default::default(),
+        prompt,
+        &mut rand::thread_rng(),
+        |t| {
+            print!("{t}");
+            std::io::stdout().flush().unwrap();
+
+            Ok(())
+        },
+    );
+
+    match res {
+        Ok(result) => println!("\n\nInference stats:\n{result}"),
+        Err(err) => println!("\n{err}"),
+    }
+}
diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs
new file mode 100644
index 00000000..09e3be6a
--- /dev/null
+++ b/crates/models/gptj/src/lib.rs
@@ -0,0 +1,523 @@
+// Ref: https://github.com/ggerganov/ggml/blob/abea4b7/examples/gpt-j/main.cpp
+
+use std::{error::Error, path::Path};
+
+use ggml::Tensor;
+use llm_base::{
+    util, BasicWriteError, EvaluateOutputRequest, FileType, InferenceParameters, InferenceSession,
+    InferenceSessionParameters, KnownModel, LoadError, LoadProgress, Mmap, TensorLoader, TokenId,
+    Vocabulary,
+};
+
+pub struct GptJ {
+    hyperparameters: Hyperparameters,
+    n_context_tokens: usize,
+
+    vocabulary: Vocabulary,
+
+    // normalization
+    ln_f_g: Tensor,
+    ln_f_b: Tensor,
+
+    // position embedding
+    wte: Tensor,
+
+    // language model head & bias
+    lmh_g: Tensor,
+    lmh_b: Tensor,
+
+    layers: Vec<Layer>,
+
+    /// Needs to be kept alive while the model is alive
+    _mmap: Option<Mmap>,
+
+    // Must be kept alive for the model
+    _context: ggml::Context,
+}
+
+unsafe impl Send for GptJ {}
+unsafe impl Sync for GptJ {}
+
+impl GptJ {
+    /// Load the model from `path` with `n_context_tokens` context tokens.
+    ///
+    /// The status of the loading process will be reported through `load_progress_callback`.
+    pub fn load(
+        path: &Path,
+        prefer_mmap: bool,
+        n_context_tokens: usize,
+        load_progress_callback: impl FnMut(LoadProgress),
+    ) -> Result<GptJ, LoadError> {
+        llm_base::load(path, prefer_mmap, n_context_tokens, load_progress_callback)
+    }
+}
+
+impl KnownModel for GptJ {
+    type Hyperparameters = Hyperparameters;
+
+    fn new(
+        hyperparameters: Self::Hyperparameters,
+        n_context_tokens: usize,
+        vocabulary: Vocabulary,
+        tensor_loader: impl TensorLoader<LoadError>,
+    ) -> Result<Self, LoadError>
+    where
+        Self: Sized,
+    {
+        let n_embd = hyperparameters.n_embd;
+        let n_layer = hyperparameters.n_layer;
+        let n_vocab = hyperparameters.n_vocab;
+
+        let mut tl = tensor_loader;
+
+        // prepare memory for weights
+        let wte = tl.load("transformer.wte.weight", &[n_embd, n_vocab])?;
+        let ln_f_g = tl.load("transformer.ln_f.weight", &[n_embd])?;
+        let ln_f_b = tl.load("transformer.ln_f.bias", &[n_embd])?;
+        let lmh_g = tl.load("lm_head.weight", &[n_embd, n_vocab])?;
+        let lmh_b = tl.load("lm_head.bias", &[n_vocab])?;
+
+        let mut layers = Vec::new();
+        for i in 0..n_layer {
+            let layer = Layer {
+                ln_1_g: tl.load(&format!("transformer.h.{i}.ln_1.weight"), &[n_embd])?,
+                ln_1_b: tl.load(&format!("transformer.h.{i}.ln_1.bias"), &[n_embd])?,
+                c_attn_q_proj_w: tl.load(
+                    &format!("transformer.h.{i}.attn.q_proj.weight"),
+                    &[n_embd, n_embd],
+                )?,
+                c_attn_k_proj_w: tl.load(
+                    &format!("transformer.h.{i}.attn.k_proj.weight"),
+                    &[n_embd, n_embd],
+                )?,
+                c_attn_v_proj_w: tl.load(
+                    &format!("transformer.h.{i}.attn.v_proj.weight"),
+                    &[n_embd, n_embd],
+                )?,
+                c_attn_proj_w: tl.load(
+                    &format!("transformer.h.{i}.attn.out_proj.weight"),
+                    &[n_embd, n_embd],
+                )?,
+                c_mlp_fc_w: tl.load(
+                    &format!("transformer.h.{i}.mlp.fc_in.weight"),
+                    &[n_embd, n_embd * 4],
+                )?,
+                c_mlp_fc_b: tl.load(&format!("transformer.h.{i}.mlp.fc_in.bias"), &[n_embd * 4])?,
+                c_mlp_proj_w: tl.load(
+                    &format!("transformer.h.{i}.mlp.fc_out.weight"),
+                    &[n_embd * 4, n_embd],
+                )?,
+                c_mlp_proj_b: tl.load(&format!("transformer.h.{i}.mlp.fc_out.bias"), &[n_embd])?,
+            };
+
+            layers.push(layer);
+        }
+
+        let (_context, _, _mmap) = tl.finish();
+
+        Ok(GptJ {
+            hyperparameters,
+            n_context_tokens,
+            vocabulary,
+            ln_f_g,
+            ln_f_b,
+            wte,
+            lmh_g,
+            lmh_b,
+            layers,
+            _mmap,
+            _context,
+        })
+    }
+
+    fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession {
+        InferenceSession::new(
+            params,
+            self.hyperparameters.n_ctx,
+            self.hyperparameters.n_layer,
+            self.hyperparameters.n_embd,
+            self.hyperparameters.n_vocab,
+        )
+    }
+
+    fn evaluate(
+        &self,
+        session: &mut InferenceSession,
+        params: &InferenceParameters,
+        input_tokens: &[TokenId],
+        output_request: &mut EvaluateOutputRequest,
+    ) {
+        let n = input_tokens.len();
+        let n_threads = params.n_threads;
+
+        let Hyperparameters {
+            n_embd,
+            n_head,
+            n_vocab,
+            n_layer,
+            n_rot,
+            ..
+        } = self.hyperparameters;
+        let n_ctx = self.n_context_tokens;
+
+        // For the first run, we need to guess a maximum buffer size so we can measure
+        // the actual memory consumption of the temporary ggml context.
+        //
+        // These numbers are from `llama.cpp`, and could potentially be more efficient.
+        let mut buf_size = {
+            let buf_size_mb = if n_layer >= 80 {
+                1536
+            } else if n_layer >= 60 {
+                1280
+            } else {
+                1024
+            };
+            buf_size_mb * 1024 * 1024
+        };
+        if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
+            // add 10% to account for ggml object overhead
+            buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
+        };
+        let ctx0 = ggml::Context::init(buf_size, true);
+
+        let mut gf = ggml::ComputationGraph::new(n_threads);
+
+        let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, n);
+        unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) };
+
+        let n_past = session.n_past;
+
+        // wte
+        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);
+
+        let memory_k = &session.memory_k;
+        let memory_k_size = memory_k.element_size();
+
+        let memory_v = &session.memory_v;
+        let memory_v_size = memory_v.element_size();
+
+        for il in 0..n_layer {
+            // norm
+            let mut current = ctx0.op_norm(&input_layer);
+            current = ctx0.op_add(
+                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
+                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
+            );
+
+            let input_sa = current.share();
+
+            // self-attention
+            let qcur = ctx0.op_rope(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_mul_mat(&self.layers[il].c_attn_q_proj_w, &current),
+                    n_embd / n_head,
+                    n_head,
+                    n,
+                ),
+                n_past,
+                n_rot,
+                0,
+            );
+            let kcur = ctx0.op_rope(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_mul_mat(&self.layers[il].c_attn_k_proj_w, &current),
+                    n_embd / n_head,
+                    n_head,
+                    n,
+                ),
+                n_past,
+                n_rot,
+                0,
+            );
+
+            // self-attention store key and value to memory
+            let vcur =
+                ctx0.op_transpose(&ctx0.op_mul_mat(&self.layers[il].c_attn_v_proj_w, &current));
+
+            let k = ctx0.op_view_1d(
+                memory_k,
+                n * n_embd,
+                (memory_k_size * n_embd) * (il * n_ctx + n_past),
+            );
+            let v = ctx0.op_view_2d(
+                memory_v,
+                (n, n_embd),
+                n_ctx * memory_v_size,
+                (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size,
+            );
+
+            gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
+            gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
+
+            let q = ctx0.op_permute(&qcur, 0, 2, 1, 3);
+            let big_k = ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        memory_k,
+                        (n_past + n) * n_embd,
+                        il * n_ctx * memory_k_size * n_embd,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n,
+                ),
+                0,
+                2,
+                1,
+                3,
+            );
+
+            let kq = ctx0.op_mul_mat(&big_k, &q);
+            let kq_scaled = ctx0.op_scale(
+                &kq,
+                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
+            );
+
+            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
+            let kq_softmax = ctx0.op_soft_max(&kq_masked);
+
+            let big_v = ctx0.op_view_3d(
+                memory_v,
+                (n_past + n, n_embd / n_head, n_head),
+                (
+                    n_ctx * memory_v_size,
+                    n_ctx * memory_v_size * n_embd / n_head,
+                ),
+                il * n_ctx * memory_v_size * n_embd,
+            );
+
+            let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
+            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);
+
+            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));
+
+            // self-attention projection
+            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
+
+            // feed-forward
+            let ff_in = current.share();
+
+            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa);
+            current = ctx0.op_add(
+                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
+                &current,
+            );
+
+            current = ctx0.op_gelu(&current);
+
+            // feed-forward projection
+            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
+            current = ctx0.op_add(
+                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
+                &current,
+            );
+
+            current = ctx0.op_add(&current, &ff_in);
+
+            // input for next layer
+            input_layer = ctx0.op_add(&current, &input_layer);
+        }
+
+        // norm
+        input_layer = ctx0.op_norm(&input_layer);
+        input_layer = ctx0.op_add(
+            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
+            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
+        );
+
+        // lm_head
+        input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);
+        input_layer = ctx0.op_add(&ctx0.op_repeat(&self.lmh_b, &input_layer), &input_layer);
+
+        gf.build_forward_expand(&input_layer);
+        // thread 'main' panicked at 'WeightedIndex error: InvalidWeight', llm-base/src/inference_session.rs:293:47
+        ctx0.graph_compute(&mut gf);
+
+        // return result for just the last token
+        // SAFETY: yolo
+        assert_eq!(session.last_logits.len(), n_vocab);
+        unsafe {
+            input_layer.read_data(
+                n_vocab * (n - 1) * std::mem::size_of::<f32>(),
+                bytemuck::cast_slice_mut(&mut session.last_logits),
+            )
+        };
+
+        // Extract logits
+        if let Some(all_logits) = &mut output_request.all_logits {
+            all_logits.resize(n_vocab * n, 0.0);
+            // SAFETY: Tensor data can be read (properly aligned, initialized,
+            // data will not be mutated or otherwise aliased during the copy),
+            // and we're not reading past the end of the tensor data.
+            assert_eq!(input_layer.nelements(), n_vocab * n);
+            unsafe {
+                input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
+            }
+        }
+
+        // Extract embeddings
+        if let Some(embeddings) = &mut output_request.embeddings {
+            embeddings.resize(n_embd * n, 0.0);
+            // SAFETY: Same rationale as for the "Extract logits" section applies.
+            assert_eq!(embd.nelements(), n_embd * n);
+            unsafe {
+                embd.read_data(0, bytemuck::cast_slice_mut(embeddings));
+            }
+        }
+
+        // Adjust the required memory per token if we didn't know that already
+        if session.mem_per_token == 0 {
+            session.mem_per_token = ctx0.used_mem() / n;
+        }
+
+        // Adjust n_past to new length.
+        session.n_past += input_tokens.len();
+    }
+
+    fn vocabulary(&self) -> &Vocabulary {
+        &self.vocabulary
+    }
+
+    fn n_context_tokens(&self) -> usize {
+        self.hyperparameters.n_ctx
+    }
+
+    fn eot_token_id(&self) -> TokenId {
+        self.vocabulary
+            .token_to_id
+            .get("<|endoftext|>".as_bytes())
+            .copied()
+            .unwrap()
+    }
+}
+
+/// The hyperparameters of the model.
+#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
+pub struct Hyperparameters {
+    /// n_vocab
+    pub n_vocab: usize,
+    /// n_ctx
+    pub n_ctx: usize,
+    /// n_embd
+    pub n_embd: usize,
+    /// n_head
+    pub n_head: usize,
+    /// n_layer
+    pub n_layer: usize,
+    /// n_rot
+    pub n_rot: usize,
+    /// file_type
+    pub file_type: FileType,
+}
+impl llm_base::Hyperparameters for Hyperparameters {
+    type WriteError = BasicWriteError;
+
+    fn read(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
+        let hyperparameters = Hyperparameters {
+            n_vocab: util::read_i32(reader)?.try_into()?,
+            n_ctx: util::read_i32(reader)?.try_into()?,
+            n_embd: util::read_i32(reader)?.try_into()?,
+            n_head: util::read_i32(reader)?.try_into()?,
+            n_layer: util::read_i32(reader)?.try_into()?,
+            n_rot: util::read_i32(reader)?.try_into()?,
+            file_type: {
+                let ftype = util::read_i32(reader)?;
+                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
+            },
+        };
+
+        let n_vocab = util::read_i32(reader)? as usize;
+        if hyperparameters.n_vocab != n_vocab {
+            return Err(LoadError::InvariantBroken {
+                path: None,
+                invariant: format!(
+                    "GPT-J model expected n_vocab {} found {}",
+                    hyperparameters.n_vocab, n_vocab
+                ),
+            });
+        }
+
+        Ok(hyperparameters)
+    }
+
+    fn write(&self, writer: &mut dyn std::io::Write) -> Result<(), Self::WriteError> {
+        util::write_i32(writer, self.n_vocab.try_into()?)?;
+        util::write_i32(writer, self.n_ctx.try_into()?)?;
+        util::write_i32(writer, self.n_embd.try_into()?)?;
+        util::write_i32(writer, self.n_head.try_into()?)?;
+        util::write_i32(writer, self.n_layer.try_into()?)?;
+        util::write_i32(writer, self.n_rot.try_into()?)?;
+        util::write_i32(writer, self.file_type.into())?;
+        Ok(())
+    }
+
+    fn n_vocabulary(&self) -> usize {
+        self.n_vocab
+    }
+}
+
+struct Layer {
+    // normalization
+    ln_1_g: Tensor,
+    ln_1_b: Tensor,
+
+    // attention
+    c_attn_q_proj_w: Tensor,
+    c_attn_k_proj_w: Tensor,
+    c_attn_v_proj_w: Tensor,
+
+    c_attn_proj_w: Tensor,
+
+    // ff
+    c_mlp_fc_w: Tensor,
+    c_mlp_fc_b: Tensor,
+
+    c_mlp_proj_w: Tensor,
+    c_mlp_proj_b: Tensor,
+}
+
+#[cfg(test)]
+impl GptJ {
+    /// This does *not* construct a valid model. All of the tensors are entirely
+    /// empty. However, it can be used to determine if some code will compile.
+    fn new_empty() -> Self {
+        let context = ggml::Context::init(1024 * 1024, true);
+
+        Self {
+            hyperparameters: Default::default(),
+            n_context_tokens: 0,
+            vocabulary: Default::default(),
+            ln_f_g: context.new_f32(0.0),
+            ln_f_b: context.new_f32(0.0),
+            wte: context.new_f32(0.0),
+            lmh_g: context.new_f32(0.0),
+            lmh_b: context.new_f32(0.0),
+            layers: Default::default(),
+            _mmap: Default::default(),
+            _context: context,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    #[test]
+    fn can_share_model_between_threads() {
+        let model = Arc::new(GptJ::new_empty());
+
+        for _ in 0..4 {
+            let model = model.clone();
+            std::thread::spawn(move || {
+                let _session = model.start_session(Default::default());
+            });
+        }
+
+        let session = model.start_session(Default::default());
+        std::thread::spawn(move || {
+            let _session = session;
+        });
+    }
+}
diff --git a/crates/models/neox/examples/neox-inference.rs b/crates/models/neox/examples/neox-inference.rs
index da7cd92a..554d4777 100644
--- a/crates/models/neox/examples/neox-inference.rs
+++ b/crates/models/neox/examples/neox-inference.rs
@@ -13,14 +13,14 @@ fn main() {
     println!(" >>> Loading model from {loc}...");
     let now = std::time::Instant::now();
 
-    let gpt2 = llm_neox::NeoX::load(Path::new(loc), true, 512, load_progress_callback_stdout)
+    let neox = llm_neox::NeoX::load(Path::new(loc), true, 512, load_progress_callback_stdout)
        .unwrap_or_else(|e| panic!("Error loading model from {loc}: {e}"));
 
     println!(" >>> Model loaded in {} ms.", now.elapsed().as_millis());
 
-    let mut session = gpt2.start_session(Default::default());
+    let mut session = neox.start_session(Default::default());
     let res = session.inference_with_prompt::<Infallible>(
-        &gpt2,
+        &neox,
         &Default::default(),
         &Default::default(),
         prompt,
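
For reference, the new backend is driven through the same `KnownModel` flow as the other model crates, exactly as the bundled `gptj-inference` example above shows. The following is a minimal sketch of that flow outside the example; the model path is a placeholder, and the call signatures are assumed to be the ones used in `gptj-inference.rs`.

use std::{convert::Infallible, io::Write, path::Path};

use llm_base::{load_progress_callback_stdout, KnownModel};

fn main() {
    // Placeholder path to a ggml-format GPT-J model file.
    let model_path = Path::new("/path/to/ggml-gpt-j-6b.bin");

    // prefer_mmap = true, 512 context tokens, progress reported to stdout.
    let gptj = llm_gptj::GptJ::load(model_path, true, 512, load_progress_callback_stdout)
        .unwrap_or_else(|e| panic!("Error loading model: {e}"));

    // One inference session with default parameters; tokens stream to stdout.
    let mut session = gptj.start_session(Default::default());
    let res = session.inference_with_prompt::<Infallible>(
        &gptj,
        &Default::default(),
        &Default::default(),
        "Rust is a cool programming language because ",
        &mut rand::thread_rng(),
        |t| {
            print!("{t}");
            std::io::stdout().flush().unwrap();
            Ok(())
        },
    );

    if let Ok(stats) = res {
        println!("\n\nInference stats:\n{stats}");
    }
}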