refactor(llm): remove llm imports from examples
philpax committed May 24, 2023
1 parent 7c35e4c commit 7c77c96
Showing 2 changed files with 23 additions and 32 deletions.
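The refactor drops the examples' "use llm::{...}" item imports and spells out fully qualified llm:: paths at each call site instead, so the example code reads the same way it would if pasted into a downstream crate whose only dependency is llm. A minimal sketch of the pattern, assuming the llm crate around this commit as a dependency; the handle_response helper below is illustrative and not part of either example:

use std::convert::Infallible;

// Before the refactor, names were imported individually:
//     use llm::{InferenceFeedback, InferenceResponse};
// After it, the crate prefix stays visible at every use.
fn handle_response(resp: llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> {
    match resp {
        // Print tokens as they stream in and keep generating.
        llm::InferenceResponse::InferredToken(t) => {
            print!("{t}");
            Ok(llm::InferenceFeedback::Continue)
        }
        // Stop when the model emits its end-of-text token.
        llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
        _ => Ok(llm::InferenceFeedback::Continue),
    }
}

The trade-off is slightly longer lines in exchange for examples that show exactly which items come from the llm crate.
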
crates/llm/examples/inference.rs: 18 changes (7 additions, 11 deletions)

@@ -1,7 +1,3 @@
-use llm::{
-    load_progress_callback_stdout as load_callback, InferenceFeedback, InferenceParameters,
-    InferenceRequest, InferenceResponse, ModelArchitecture,
-};
 use std::{convert::Infallible, io::Write, path::Path};

 fn main() {
@@ -11,7 +7,7 @@ fn main() {
         std::process::exit(1);
     }

-    let model_architecture: ModelArchitecture = raw_args[0].parse().unwrap();
+    let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap();
     let model_path = Path::new(&raw_args[1]);
     let prompt = raw_args
         .get(2)
@@ -26,7 +22,7 @@ fn main() {
         model_path,
         Default::default(),
         overrides,
-        load_callback,
+        llm::load_progress_callback_stdout,
     )
     .unwrap_or_else(|err| {
         panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
@@ -42,22 +38,22 @@ fn main() {
     let res = session.infer::<Infallible>(
         model.as_ref(),
         &mut rand::thread_rng(),
-        &InferenceRequest {
+        &llm::InferenceRequest {
             prompt: prompt.into(),
-            parameters: &InferenceParameters::default(),
+            parameters: &llm::InferenceParameters::default(),
             play_back_previous_tokens: false,
             maximum_token_count: None,
         },
         // OutputRequest
         &mut Default::default(),
         |r| match r {
-            InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
+            llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
                 print!("{t}");
                 std::io::stdout().flush().unwrap();

-                Ok(InferenceFeedback::Continue)
+                Ok(llm::InferenceFeedback::Continue)
             }
-            _ => Ok(InferenceFeedback::Continue),
+            _ => Ok(llm::InferenceFeedback::Continue),
         },
     );

crates/llm/examples/vicuna-chat.rs: 37 changes (16 additions, 21 deletions)

@@ -1,7 +1,3 @@
-use llm::{
-    load_progress_callback_stdout, InferenceFeedback, InferenceParameters, InferenceRequest,
-    InferenceResponse, InferenceStats, ModelArchitecture,
-};
 use rustyline::error::ReadlineError;
 use std::{convert::Infallible, io::Write, path::Path};

@@ -12,7 +8,7 @@ fn main() {
         std::process::exit(1);
     }

-    let model_architecture: ModelArchitecture = raw_args[0].parse().unwrap();
+    let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap();
     let model_path = Path::new(&raw_args[1]);
     let overrides = raw_args.get(2).map(|s| serde_json::from_str(s).unwrap());

@@ -21,7 +17,7 @@ fn main() {
         model_path,
         Default::default(),
         overrides,
-        load_progress_callback_stdout,
+        llm::load_progress_callback_stdout,
     )
     .unwrap_or_else(|err| {
         panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
@@ -38,7 +34,7 @@ fn main() {
         {character_name}: Paris is the capital of France."
     );

-    let inference_parameters = InferenceParameters::default();
+    let inference_parameters = llm::InferenceParameters::default();

     session
         .feed_prompt(
@@ -47,18 +43,17 @@ fn main() {
             format!("{persona}\n{history}").as_str(),
             &mut Default::default(),
             llm::feed_prompt_callback(|resp| match resp {
-                InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
-                    print_token(t)
-                }
-                _ => Ok(InferenceFeedback::Continue),
+                llm::InferenceResponse::PromptToken(t)
+                | llm::InferenceResponse::InferredToken(t) => print_token(t),
+                _ => Ok(llm::InferenceFeedback::Continue),
             }),
         )
         .expect("Failed to ingest initial prompt.");

     let mut rl = rustyline::DefaultEditor::new().expect("Failed to create input reader");

     let mut rng = rand::thread_rng();
-    let mut res = InferenceStats::default();
+    let mut res = llm::InferenceStats::default();
     let mut buf = String::new();

     loop {
@@ -71,7 +66,7 @@ fn main() {
             .infer(
                 model.as_ref(),
                 &mut rng,
-                &InferenceRequest {
+                &llm::InferenceRequest {
                     prompt: format!("{user_name}: {line}\n{character_name}:")
                         .as_str()
                         .into(),
@@ -106,17 +101,17 @@ fn main() {
 fn inference_callback(
     stop_sequence: String,
     buf: &mut String,
-) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
+) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + '_ {
     move |resp| match resp {
-        InferenceResponse::InferredToken(t) => {
+        llm::InferenceResponse::InferredToken(t) => {
             let mut reverse_buf = buf.clone();
             reverse_buf.push_str(t.as_str());
             if stop_sequence.as_str().eq(reverse_buf.as_str()) {
                 buf.clear();
-                return Ok(InferenceFeedback::Halt);
+                return Ok(llm::InferenceFeedback::Halt);
             } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
                 buf.push_str(t.as_str());
-                return Ok(InferenceFeedback::Continue);
+                return Ok(llm::InferenceFeedback::Continue);
             }

             if buf.is_empty() {
@@ -125,14 +120,14 @@ fn inference_callback(
                 print_token(reverse_buf)
             }
         }
-        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
-        _ => Ok(InferenceFeedback::Continue),
+        llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
+        _ => Ok(llm::InferenceFeedback::Continue),
     }
 }

-fn print_token(t: String) -> Result<InferenceFeedback, Infallible> {
+fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
     print!("{t}");
     std::io::stdout().flush().unwrap();

-    Ok(InferenceFeedback::Continue)
+    Ok(llm::InferenceFeedback::Continue)
 }
