refactor(llm): simplify vicuna-chat example

philpax committed May 24, 2023
1 parent 421bd49 commit 7c35e4c
Showing 1 changed file with 14 additions and 80 deletions.
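
This commit replaces the example's hand-rolled, spinner-based load_progress_callback (and its Instant-based throttling) with the library-provided load_progress_callback_stdout, inlines the standalone prompt_callback function as a closure passed to llm::feed_prompt_callback, and swaps the explicit 'a lifetime on inference_callback for the anonymous '_ lifetime, dropping the #[allow(clippy::needless_lifetimes)] suppression. It also fixes a typo in the hard-coded chat history ("capital or France" → "capital of France").

For orientation, a minimal sketch of the load path as it stands after this commit; the call is lifted verbatim from the diff below, with load_progress_callback_stdout imported from llm:

    // Load the model, reporting progress to stdout via the stock callback
    // instead of the custom spinner-based reporter this commit removes.
    let model = llm::load_dynamic(
        model_architecture,
        model_path,
        Default::default(),
        overrides,
        load_progress_callback_stdout,
    )
    .unwrap_or_else(|err| {
        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
    });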
94 changes: 14 additions & 80 deletions crates/llm/examples/vicuna-chat.rs
@@ -1,10 +1,9 @@
 use llm::{
-    InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceStats,
-    LoadProgress, ModelArchitecture,
+    load_progress_callback_stdout, InferenceFeedback, InferenceParameters, InferenceRequest,
+    InferenceResponse, InferenceStats, ModelArchitecture,
 };
 use rustyline::error::ReadlineError;
-use spinoff::{spinners::Dots2, Spinner};
-use std::{convert::Infallible, io::Write, path::Path, time::Instant};
+use std::{convert::Infallible, io::Write, path::Path};
 
 fn main() {
     let raw_args: Vec<String> = std::env::args().skip(1).collect();
@@ -16,17 +15,13 @@ fn main() {
     let model_architecture: ModelArchitecture = raw_args[0].parse().unwrap();
     let model_path = Path::new(&raw_args[1]);
     let overrides = raw_args.get(2).map(|s| serde_json::from_str(s).unwrap());
-    let sp = Some(Spinner::new(Dots2, "Loading model...", None));
-
-    let now = Instant::now();
-    let prev_load_time = now;
 
     let model = llm::load_dynamic(
         model_architecture,
         model_path,
         Default::default(),
         overrides,
-        load_progress_callback(sp, now, prev_load_time),
+        load_progress_callback_stdout,
     )
     .unwrap_or_else(|err| {
         panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
@@ -39,7 +34,7 @@ fn main() {
     let persona = "A chat between a human and an assistant.";
     let history = format!(
         "{character_name}: Hello - How may I help you today?\n\
-        {user_name}: What is the capital or France?\n\
+        {user_name}: What is the capital of France?\n\
         {character_name}: Paris is the capital of France."
     );
 
@@ -51,7 +46,12 @@ fn main() {
         &inference_parameters,
         format!("{persona}\n{history}").as_str(),
         &mut Default::default(),
-        llm::feed_prompt_callback(prompt_callback),
+        llm::feed_prompt_callback(|resp| match resp {
+            InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
+                print_token(t)
+            }
+            _ => Ok(InferenceFeedback::Continue),
+        }),
     )
     .expect("Failed to ingest initial prompt.");
 
@@ -103,76 +103,10 @@ fn main() {
     println!("\n\nInference stats:\n{res}");
 }
 
-fn load_progress_callback(
-    mut sp: Option<Spinner>,
-    now: Instant,
-    mut prev_load_time: Instant,
-) -> impl FnMut(LoadProgress) {
-    move |progress| match progress {
-        LoadProgress::HyperparametersLoaded => {
-            if let Some(sp) = sp.as_mut() {
-                sp.update_text("Loaded hyperparameters")
-            };
-        }
-        LoadProgress::ContextSize { bytes } => log::debug!(
-            "ggml ctx size = {}",
-            bytesize::to_string(bytes as u64, false)
-        ),
-        LoadProgress::TensorLoaded {
-            current_tensor,
-            tensor_count,
-            ..
-        } => {
-            if prev_load_time.elapsed().as_millis() > 500 {
-                // We don't want to re-render this on every message, as that causes the
-                // spinner to constantly reset and not look like it's spinning (and
-                // it's obviously wasteful).
-                if let Some(sp) = sp.as_mut() {
-                    sp.update_text(format!(
-                        "Loaded tensor {}/{}",
-                        current_tensor + 1,
-                        tensor_count
-                    ));
-                };
-                prev_load_time = std::time::Instant::now();
-            }
-        }
-        LoadProgress::LoraApplied { name, source } => {
-            if let Some(sp) = sp.as_mut() {
-                sp.update_text(format!(
-                    "Applied LoRA: {} from '{}'",
-                    name,
-                    source.file_name().unwrap().to_str().unwrap()
-                ));
-            };
-        }
-        LoadProgress::Loaded {
-            file_size,
-            tensor_count,
-        } => {
-            if let Some(sp) = sp.take() {
-                sp.success(&format!(
-                    "Loaded {tensor_count} tensors ({}) after {}ms",
-                    bytesize::to_string(file_size, false),
-                    now.elapsed().as_millis()
-                ));
-            };
-        }
-    }
-}
-
-fn prompt_callback(resp: InferenceResponse) -> Result<InferenceFeedback, Infallible> {
-    match resp {
-        InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => print_token(t),
-        _ => Ok(InferenceFeedback::Continue),
-    }
-}
-
-#[allow(clippy::needless_lifetimes)]
-fn inference_callback<'a>(
+fn inference_callback(
     stop_sequence: String,
-    buf: &'a mut String,
-) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + 'a {
+    buf: &mut String,
+) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
     move |resp| match resp {
         InferenceResponse::InferredToken(t) => {
             let mut reverse_buf = buf.clone();

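A note on the signature change at the bottom of the diff: with a single borrowed parameter, the explicit 'a lifetime adds nothing, which is why the original code needed #[allow(clippy::needless_lifetimes)]. Both signatures below come straight from the diff and are equivalent; the commit adopts the second:

    // Before: explicit lifetime, plus a clippy suppression.
    fn inference_callback<'a>(
        stop_sequence: String,
        buf: &'a mut String,
    ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + 'a

    // After: the anonymous lifetime '_ is inferred from `buf`.
    fn inference_callback(
        stop_sequence: String,
        buf: &mut String,
    ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_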