Add a chat example that uses inference feedback
danforbes committed May 11, 2023
1 parent 3b47302 commit 5502f08
Showing 7 changed files with 222 additions and 23 deletions.
18 changes: 16 additions & 2 deletions .vscode/launch.json
@@ -67,15 +67,29 @@
         {
             "type": "lldb",
             "request": "launch",
-            "name": "Debug NeoX Inference",
+            "name": "Debug GPT-NeoX Inference",
             "cargo": {
                 "args": ["build", "--example=inference", "--package=llm"],
                 "filter": {
                     "name": "inference",
                     "kind": "example"
                 }
             },
-            "args": ["neox", "${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"],
+            "args": ["gptneox", "${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"],
             "cwd": "${workspaceFolder}"
         },
+        {
+            "type": "lldb",
+            "request": "launch",
+            "name": "Debug Vicuna Chat",
+            "cargo": {
+                "args": ["build", "--example=vicuna-chat", "--package=llm"],
+                "filter": {
+                    "name": "vicuna-chat",
+                    "kind": "example"
+                }
+            },
+            "args": ["llama", "${env:HOME}/.ggml-models/wizardlm-7b.bin"],
+            "cwd": "${workspaceFolder}"
+        }
     ]
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -17,9 +17,12 @@ license = "MIT OR Apache-2.0"
 
 [workspace.dependencies]
 bytemuck = "1.13.1"
+bytesize = "1.1"
 log = "0.4"
 rand = "0.8.5"
+rustyline = { version = "11.0.0", features = ["derive"] }
 serde = { version = "1.0", features = ["derive"] }
+spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
 thiserror = "1.0"
 
 # Config for 'cargo dist'
6 changes: 3 additions & 3 deletions binaries/llm-cli/Cargo.toml
@@ -14,15 +14,15 @@ path = "src/main.rs"
 [dependencies]
 llm = { path = "../../crates/llm", version = "0.1.1" }
 
+bytesize = { workspace = true }
 log = { workspace = true }
 rand = { workspace = true }
+rustyline = { workspace = true }
+spinoff = { workspace = true }
 
 bincode = "1.3.3"
-bytesize = "1.1"
 env_logger = "0.10.0"
 num_cpus = "1.15.0"
-rustyline = { version = "11.0.0", features = ["derive"] }
-spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
 
 clap = { version = "4.1.8", features = ["derive"] }
 color-eyre = { version = "0.6.2", default-features = false }
20 changes: 2 additions & 18 deletions crates/llm-base/src/inference_session.rs
@@ -101,14 +101,7 @@ impl InferenceSession {
                     Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
                     Ok(f) => match f {
                         InferenceFeedback::Continue => (),
-                        InferenceFeedback::FeedPrompt(_) => {
-                            return Err(InferenceError::UserCallback(Some(
-                                "Cannot interrupt prompt ingestion".into(),
-                            )))
-                        }
-                        InferenceFeedback::Halt => {
-                            return Err(InferenceError::UserCallback(None))
-                        }
+                        InferenceFeedback::Halt => break,
                     },
                 }
             }
@@ -219,14 +212,7 @@ impl InferenceSession {
                 Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
                 Ok(f) => match f {
                     InferenceFeedback::Continue => (),
-                    InferenceFeedback::FeedPrompt(p) => self.feed_prompt(
-                        model,
-                        parameters,
-                        &p,
-                        output_request,
-                        TokenUtf8Buffer::adapt_callback(&mut callback),
-                    )?,
-                    InferenceFeedback::Halt => return Err(InferenceError::UserCallback(None)),
+                    InferenceFeedback::Halt => break,
                 },
             }
         }
@@ -644,8 +630,6 @@ pub enum InferenceResponse {
 pub enum InferenceFeedback {
     /// Continue inference
     Continue,
-    /// Feed the provided text into the inference session
-    FeedPrompt(String),
     /// Halt inference
     Halt,
 }
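
The practical effect of this change: a callback ends generation early simply by returning InferenceFeedback::Halt, which now breaks out of the token loop instead of being reported as a UserCallback error, and the FeedPrompt variant is gone entirely. Below is a minimal sketch of a callback written against the revised contract; the capped_callback helper and its token cap are illustrative assumptions, not part of this commit.

use std::convert::Infallible;

use llm_base::{InferenceFeedback, InferenceResponse};

// Hypothetical helper: stop once `max_tokens` tokens have been generated.
fn capped_callback(
    max_tokens: usize,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> {
    let mut generated = 0usize;
    move |resp| match resp {
        InferenceResponse::InferredToken(t) => {
            print!("{t}");
            generated += 1;
            // Returning Halt now ends inference cleanly rather than surfacing an error.
            if generated >= max_tokens {
                Ok(InferenceFeedback::Halt)
            } else {
                Ok(InferenceFeedback::Continue)
            }
        }
        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
        _ => Ok(InferenceFeedback::Continue),
    }
}

The vicuna-chat example added below relies on the same mechanism, halting when the accumulated output matches its "### Human" stop sequence or when the end-of-text token arrives.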
4 changes: 4 additions & 0 deletions crates/llm/Cargo.toml
@@ -18,7 +18,11 @@ llm-neox = { path = "../models/neox", optional = true, version = "0.1.1" }
 serde = { workspace = true }
 
 [dev-dependencies]
+bytesize = { workspace = true }
+log = { workspace = true }
 rand = { workspace = true }
+rustyline = { workspace = true }
+spinoff = { workspace = true }
 
 [features]
 default = ["llama", "gpt2", "gptj", "bloom", "neox"]
190 changes: 190 additions & 0 deletions crates/llm/examples/vicuna-chat.rs
@@ -0,0 +1,190 @@
use llm_base::{
    InferenceFeedback, InferenceRequest, InferenceResponse, InferenceStats, LoadProgress,
    TokenUtf8Buffer,
};
use rustyline::error::ReadlineError;
use spinoff::{spinners::Dots2, Spinner};
use std::{convert::Infallible, env::args, io::Write, path::Path, time::Instant};

fn main() {
    let raw_args: Vec<String> = args().collect();
    let args = match &raw_args.len() {
        3 => (raw_args[1].as_str(), raw_args[2].as_str()),
        _ => {
            panic!("Usage: cargo run --release --example vicuna-chat <model type> <path to model>")
        }
    };

    let model_type = args.0;
    let model_path = Path::new(args.1);

    let architecture = model_type.parse().unwrap_or_else(|e| panic!("{e}"));

    let sp = Some(Spinner::new(Dots2, "Loading model...", None));

    let now = Instant::now();
    let prev_load_time = now;

    let model = llm::load_dynamic(
        architecture,
        model_path,
        Default::default(),
        load_progress_callback(sp, now, prev_load_time),
    )
    .unwrap_or_else(|err| panic!("Failed to load {model_type} model from {model_path:?}: {err}"));

    let mut session = model.start_session(Default::default());

    let character_name = "### Assistant";
    let user_name = "### Human";
    let persona = "A chat between a human and an assistant.";
    let history = format!(
        "{character_name}: Hello - How may I help you today?\n\
        {user_name}: What is the capital of France?\n\
        {character_name}: Paris is the capital of France."
    );

    session
        .feed_prompt(
            model.as_ref(),
            &Default::default(),
            format!("{persona}\n{history}").as_str(),
            &mut Default::default(),
            TokenUtf8Buffer::adapt_callback(prompt_callback),
        )
        .expect("Failed to ingest initial prompt.");

    let mut rl = rustyline::DefaultEditor::new().expect("Failed to create input reader");

    let mut rng = rand::thread_rng();
    let mut res = InferenceStats::default();
    let mut buf = String::new();

    loop {
        println!();
        let readline = rl.readline(format!("{user_name}: ").as_str());
        print!("{character_name}:");
        match readline {
            Ok(line) => {
                let stats = session
                    .infer(
                        model.as_ref(),
                        &mut rng,
                        &InferenceRequest {
                            prompt: format!("{user_name}: {line}\n{character_name}:").as_str(),
                            ..Default::default()
                        },
                        &mut Default::default(),
                        inference_callback(String::from(user_name), &mut buf),
                    )
                    .unwrap_or_else(|e| panic!("{e}"));

                res.feed_prompt_duration = res
                    .feed_prompt_duration
                    .saturating_add(stats.feed_prompt_duration);
                res.prompt_tokens += stats.prompt_tokens;
                res.predict_duration = res.predict_duration.saturating_add(stats.predict_duration);
                res.predict_tokens += stats.predict_tokens;
            }
            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
                break;
            }
            Err(err) => {
                println!("{err}");
            }
        }
    }

    println!("\n\nInference stats:\n{res}");
}

fn load_progress_callback(
    mut sp: Option<Spinner>,
    now: Instant,
    mut prev_load_time: Instant,
) -> impl FnMut(LoadProgress) {
    move |progress| match progress {
        LoadProgress::HyperparametersLoaded => {
            if let Some(sp) = sp.as_mut() {
                sp.update_text("Loaded hyperparameters")
            };
        }
        LoadProgress::ContextSize { bytes } => log::debug!(
            "ggml ctx size = {}",
            bytesize::to_string(bytes as u64, false)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
            ..
        } => {
            if prev_load_time.elapsed().as_millis() > 500 {
                // We don't want to re-render this on every message, as that causes the
                // spinner to constantly reset and not look like it's spinning (and
                // it's obviously wasteful).
                if let Some(sp) = sp.as_mut() {
                    sp.update_text(format!(
                        "Loaded tensor {}/{}",
                        current_tensor + 1,
                        tensor_count
                    ));
                };
                prev_load_time = std::time::Instant::now();
            }
        }
        LoadProgress::Loaded {
            file_size,
            tensor_count,
        } => {
            if let Some(sp) = sp.take() {
                sp.success(&format!(
                    "Loaded {tensor_count} tensors ({}) after {}ms",
                    bytesize::to_string(file_size, false),
                    now.elapsed().as_millis()
                ));
            };
        }
    }
}

fn prompt_callback(resp: InferenceResponse) -> Result<InferenceFeedback, Infallible> {
    match resp {
        InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => print_token(t),
        _ => Ok(InferenceFeedback::Continue),
    }
}

#[allow(clippy::needless_lifetimes)]
fn inference_callback<'a>(
    stop_sequence: String,
    buf: &'a mut String,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + 'a {
    move |resp| match resp {
        InferenceResponse::InferredToken(t) => {
            let mut reverse_buf = buf.clone();
            reverse_buf.push_str(t.as_str());
            if stop_sequence.as_str().eq(reverse_buf.as_str()) {
                buf.clear();
                return Ok(InferenceFeedback::Halt);
            } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
                buf.push_str(t.as_str());
                return Ok(InferenceFeedback::Continue);
            }

            if buf.is_empty() {
                print_token(t)
            } else {
                print_token(reverse_buf)
            }
        }
        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
        _ => Ok(InferenceFeedback::Continue),
    }
}

fn print_token(t: String) -> Result<InferenceFeedback, Infallible> {
    print!("{t}");
    std::io::stdout().flush().unwrap();

    Ok(InferenceFeedback::Continue)
}
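
To try the example locally, the debug configuration added to .vscode/launch.json suggests an invocation along the following lines; the model path is an assumption, and any LLaMA-architecture chat model in GGML format should work in its place:

cargo run --release --example vicuna-chat llama ${HOME}/.ggml-models/wizardlm-7b.bin

The first argument selects the model architecture and the second points at the model file, matching the "<model type> <path to model>" usage string in main().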
