GPT-J model implementation
danforbes committed May 2, 2023
1 parent d33ed84 commit 0b87983
Showing 10 changed files with 623 additions and 5 deletions.
20 changes: 19 additions & 1 deletion .vscode/launch.json
@@ -32,6 +32,24 @@
"args": ["${env:HOME}/.ggml-models/cerebras-gpt-13b.bin"],
"cwd": "${workspaceFolder}"
},
+{
+"type": "lldb",
+"request": "launch",
+"name": "Debug GPT-J Inference",
+"cargo": {
+"args": [
+"build",
+"--example=gptj-inference",
+"--package=llm-gptj"
+],
+"filter": {
+"name": "gptj-inference",
+"kind": "example"
+}
+},
+"args": ["${env:HOME}/.ggml-models/gpt-j-6b.bin"],
+"cwd": "${workspaceFolder}"
+},
{
"type": "lldb",
"request": "launch",
@@ -57,7 +75,7 @@
"kind": "example"
}
},
"args": ["${env:HOME}/.ggml-models/stablelm-base-alpha-3b-f16.bin"],
"args": ["${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"],
"cwd": "${workspaceFolder}"
}
]
11 changes: 11 additions & 0 deletions Cargo.lock


6 changes: 6 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -26,6 +26,12 @@ pub enum Args {
#[command(subcommand)]
args: BaseArgs,
},
+/// Use a GPT-J model
+#[clap(id = "gptj")]
+GptJ {
+#[command(subcommand)]
+args: BaseArgs,
+},
/// Use a GPT-NeoX model
#[clap(id = "neox")]
NeoX {
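
For context, the new variant slots straight into the CLI's clap-derive setup. Below is a minimal, runnable sketch of the pattern, with a hypothetical one-variant BaseArgs standing in for the repository's real definition, which lies outside this diff:

use clap::{Parser, Subcommand};

/// Hypothetical stand-in for the repository's `BaseArgs` (defined elsewhere in llm-cli).
#[derive(Subcommand, Debug)]
enum BaseArgs {
    /// Run inference.
    Infer,
}

#[derive(Parser, Debug)]
enum Args {
    /// Use a GPT-J model
    #[clap(id = "gptj")]
    GptJ {
        #[command(subcommand)]
        args: BaseArgs,
    },
}

fn main() {
    // `<binary> gptj infer` parses into Args::GptJ { args: BaseArgs::Infer }.
    println!("{:?}", Args::parse());
}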
1 change: 1 addition & 0 deletions binaries/llm-cli/src/main.rs
@@ -25,6 +25,7 @@ fn main() -> Result<()> {
Args::Llama { args } => handle_args::<llm::models::Llama>(args),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
+Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
Args::NeoX { args } => handle_args::<llm::models::NeoX>(args),
}
}
4 changes: 3 additions & 1 deletion crates/llm/Cargo.toml
@@ -7,12 +7,14 @@ edition = "2021"
llm-base = { path = "../llm-base" }
llm-llama = { path = "../models/llama", features = ["convert"], optional = true }
llm-gpt2 = { path = "../models/gpt2", optional = true }
+llm-gptj = { path = "../models/gptj", optional = true }
llm-bloom = { path = "../models/bloom", optional = true }
llm-neox = { path = "../models/neox", optional = true }

[features]
default = ["llama", "gpt2", "bloom", "neox"]
default = ["llama", "gpt2", "gptj", "bloom", "neox"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
+gptj = ["dep:llm-gptj"]
bloom = ["dep:llm-bloom"]
neox = ["dep:llm-neox"]
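
Because each backend now sits behind its own Cargo feature, downstream users can trim the build to just this model, e.g. `llm = { path = "…", default-features = false, features = ["gptj"] }`; the `default` list above merely enables every backend at once.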
2 changes: 2 additions & 0 deletions crates/llm/src/lib.rs
@@ -20,6 +20,8 @@ pub mod models {
pub use llm_bloom::{self as bloom, Bloom};
#[cfg(feature = "gpt2")]
pub use llm_gpt2::{self as gpt2, Gpt2};
#[cfg(feature = "gptj")]
pub use llm_gptj::{self as gptj, GptJ};
#[cfg(feature = "llama")]
pub use llm_llama::{self as llama, Llama};
#[cfg(feature = "neox")]
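
With the re-export in place, downstream code can reach the model through the llm facade. A minimal sketch, mirroring the load call from the gptj-inference example further down — the path is a placeholder, and the interpretation of the `true` and `512` arguments (memory-mapping preference and context length) is an assumption based on the llm-base API of this era:

use std::path::Path;

use llm_base::{load_progress_callback_stdout, KnownModel};

fn main() {
    // Placeholder path: point this at a GGML-format GPT-J weights file.
    let model = llm::models::GptJ::load(
        Path::new("/path/to/gpt-j-6b.bin"),
        true, // assumed: prefer memory-mapped loading (same flag the example passes)
        512,  // assumed: context token budget (same value the example passes)
        load_progress_callback_stdout,
    )
    .unwrap_or_else(|e| panic!("Error loading model: {e}"));

    // Sessions are then started exactly as with the other backends.
    let _session = model.start_session(Default::default());
}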
15 changes: 15 additions & 0 deletions crates/models/gptj/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "llm-gptj"
version = { workspace = true }
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llm-base = { path = "../../llm-base" }
ggml = { path = "../../ggml" }

bytemuck = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
40 changes: 40 additions & 0 deletions crates/models/gptj/examples/gptj-inference.rs
@@ -0,0 +1,40 @@
use std::{convert::Infallible, env::args, io::Write, path::Path};

use llm_base::{load_progress_callback_stdout, KnownModel};

fn main() {
    let args: Vec<String> = args().collect();

    // First CLI argument: path to a GGML-format GPT-J model file.
    let loc = args
        .get(1)
        .expect("usage: gptj-inference <model-path> [prompt]");

    // Second, optional CLI argument: the prompt; otherwise use a default.
    let prompt = match args.len() {
        3 => args[2].as_str(),
        _ => "Rust is a cool programming language because ",
    };

    println!(" >>> Loading model from {loc}...");
    let now = std::time::Instant::now();

    let gptj = llm_gptj::GptJ::load(Path::new(loc), true, 512, load_progress_callback_stdout)
        .unwrap_or_else(|e| panic!("Error loading model from {loc}: {e}"));

    println!(" >>> Model loaded in {} ms.", now.elapsed().as_millis());

    // Run inference with default session and inference parameters, streaming
    // each generated token to stdout as it is produced.
    let mut session = gptj.start_session(Default::default());
    let res = session.inference_with_prompt::<Infallible>(
        &gptj,
        &Default::default(),
        &Default::default(),
        prompt,
        &mut rand::thread_rng(),
        |t| {
            print!("{t}");
            std::io::stdout().flush().unwrap();

            Ok(())
        },
    );

    match res {
        Ok(result) => println!("\n\nInference stats:\n{result}"),
        Err(err) => println!("\n{err}"),
    }
}
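
Assuming the workspace layout above, the example runs with `cargo run --example gptj-inference --package llm-gptj -- <path-to-model> [prompt]`; the prompt argument is optional and falls back to the built-in default, matching the launch configuration added to .vscode/launch.json.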
