Add back BLOOM
Co-authored-by: @hhamud <53880692+hhamud@users.noreply.github.com>
danforbes and hhamud committed Apr 30, 2023
1 parent 608090b commit 78db42c
Showing 7 changed files with 614 additions and 4 deletions.
18 changes: 18 additions & 0 deletions .vscode/launch.json
@@ -4,6 +4,24 @@
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "type": "lldb",
            "request": "launch",
            "name": "Debug example 'bloom_inference'",
            "cargo": {
                "args": [
                    "build",
                    "--example=bloom_inference",
                    "--package=bloom"
                ],
                "filter": {
                    "name": "bloom_inference",
                    "kind": "example"
                }
            },
            "args": ["${env:HOME}/.ggml-models/bloom-7b.bin"],
            "cwd": "${workspaceFolder}"
        },
        {
            "type": "lldb",
            "request": "launch",
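For reference, the configuration above corresponds to roughly `cargo run --example bloom_inference --package bloom -- $HOME/.ggml-models/bloom-7b.bin` when run outside the debugger, substituting whichever path actually holds the GGML-format BLOOM weights.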
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -1,11 +1,12 @@
[workspace]
members = [
    # Crates
    "bloom",
    "ggml",
    "llm-base",
    "gpt2",
    "llama",
    "llm",
    "llm-base",
    "llm-cli",
]
resolver = "2"
14 changes: 11 additions & 3 deletions README.md
@@ -21,9 +21,9 @@ quantized versions of the model.

Make sure you have Rust 1.65.0 or above and a C toolchain[^1] set up.

-`llm-base`, `gpt2`, and `llama` are Rust libraries, while `llm-cli` is a CLI
-applications that wraps `gpt2` and `llama` and offer basic inference
-capabilities.
+`llm-base` and the model crates (e.g. `bloom`, `gpt2`, `llama`) are Rust
+libraries, while `llm-cli` is a CLI application that wraps the models and
+offers basic inference capabilities.

The following instructions explain how to build CLI applications.

@@ -103,6 +103,14 @@ cargo run -p llama-cli quantize /path/to/your/models/7B/ggml-model-f16.bin /path
> The [llama.cpp repository](https://github.com/ggerganov/llama.cpp) has
> additional information on how to obtain and run specific models.

### BLOOM

The open-source [BLOOM](https://bigscience.huggingface.co/blog/bloom) model is
also supported.
[More information](https://huggingface.co/docs/transformers/model_doc/bloom)
about BLOOM is available on HuggingFace, as are some
[quantized models](https://huggingface.co/models?search=bloom%20ggml).

### GPT2

OpenAI's [GPT-2](https://jalammar.github.io/illustrated-gpt2/) architecture is
15 changes: 15 additions & 0 deletions bloom/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "bloom"
version = { workspace = true }
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ggml = { path = "../ggml" }
llm-base = { path = "../llm-base" }

bytemuck = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
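If another crate in the workspace wants to use the new model, its manifest would presumably point at this package with a path dependency along the lines of `bloom = { path = "../bloom" }`; that line is an illustration, not part of the commit.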
42 changes: 42 additions & 0 deletions bloom/examples/bloom_inference.rs
@@ -0,0 +1,42 @@
use std::{convert::Infallible, env::args, io::Write};

use llm_base::{load_progress_callback, model::KnownModel};

extern crate bloom;

fn main() {
    // First CLI argument: path to the GGML model file; optional second argument: a custom prompt.
    let args: Vec<String> = args().collect();
    let loc = &args[1];
    let prompt = match &args.len() {
        3 => &args[2],
        _ => "Rust is a cool programming language because ",
    };

    println!(" >>> Loading model from {loc}...");
    let now = std::time::Instant::now();

    let bloom = bloom::Bloom::load(loc, true, 512, load_progress_callback)
        .unwrap_or_else(|e| panic!("Error loading model from {loc}: {e}"));

    println!(" >>> Model loaded in {} ms.", now.elapsed().as_millis());

    // Start an inference session and stream each generated token to stdout as it arrives.
    let mut session = bloom.start_session(Default::default());
    let res = session.inference_with_prompt::<Infallible>(
        &bloom,
        &Default::default(),
        prompt,
        None,
        &mut rand::thread_rng(),
        |t| {
            print!("{t}");
            std::io::stdout().flush().unwrap();

            Ok(())
        },
    );

    match res {
        Ok(result) => println!("\n\nInference stats:\n{result}"),
        Err(err) => println!("\n{err}"),
    }
}
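The example above streams tokens straight to stdout. A minimal sketch, not taken from the commit, of collecting the generated text into a `String` with the same calls; it assumes the token callback receives string slices (as the `print!("{t}")` above suggests) and reuses the `true` flag and `512` context-size argument verbatim, since the commit does not document them:

use std::convert::Infallible;

use llm_base::{load_progress_callback, model::KnownModel};

// Sketch only: returns the generated text instead of printing it token by token.
fn generate_to_string(model_path: &str, prompt: &str) -> String {
    // Load the model with the same arguments the example uses.
    let bloom = bloom::Bloom::load(model_path, true, 512, load_progress_callback)
        .unwrap_or_else(|e| panic!("Error loading model from {model_path}: {e}"));

    let mut session = bloom.start_session(Default::default());
    let mut output = String::new();

    // Append each generated piece to `output` rather than writing it to stdout.
    let _ = session.inference_with_prompt::<Infallible>(
        &bloom,
        &Default::default(),
        prompt,
        None,
        &mut rand::thread_rng(),
        |t| {
            output.push_str(t);
            Ok(())
        },
    );

    output
}

Nothing here goes beyond the API surface already exercised by `bloom_inference.rs`: `Bloom::load`, `start_session`, and `inference_with_prompt` with the same argument shapes.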