Skip to content

Commit

Permalink
revise example arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed May 29, 2023
1 parent c5ce2d1 commit bf30f84
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
16 changes: 9 additions & 7 deletions crates/llm/examples/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use clap::Parser;

#[derive(Parser)]
struct Args {
architecture: String,
path: PathBuf,
model_architecture: llm::ModelArchitecture,
model_path: PathBuf,
#[arg(long, short = 'v')]
pub vocabulary_path: Option<PathBuf>,
#[arg(long, short = 'r')]
Expand All @@ -32,8 +32,8 @@ fn main() {
let args = Args::parse();

let vocabulary_source = args.to_vocabulary_source();
let architecture = args.architecture.parse().unwrap();
let path = args.path;
let model_architecture = args.model_architecture;
let model_path = args.model_path;
let query = args
.query
.as_deref()
Expand All @@ -55,13 +55,15 @@ fn main() {
lora_adapters: None,
};
let model = llm::load_dynamic(
architecture,
&path,
model_architecture,
&model_path,
vocabulary_source,
model_params,
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
.unwrap_or_else(|err| {
panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
});
let inference_parameters = llm::InferenceParameters::default();

// Generate embeddings for query and comparands
Expand Down
16 changes: 9 additions & 7 deletions crates/llm/examples/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{convert::Infallible, io::Write, path::PathBuf};

#[derive(Parser)]
struct Args {
architecture: String,
path: PathBuf,
model_architecture: llm::ModelArchitecture,
model_path: PathBuf,
#[arg(long, short = 'p')]
prompt: Option<String>,
#[arg(long, short = 'v')]
Expand All @@ -29,8 +29,8 @@ fn main() {
let args = Args::parse();

let vocabulary_source = args.to_vocabulary_source();
let architecture = args.architecture.parse().unwrap();
let path = args.path;
let model_architecture = args.model_architecture;
let model_path = args.model_path;
let prompt = args
.prompt
.as_deref()
Expand All @@ -39,13 +39,15 @@ fn main() {
let now = std::time::Instant::now();

let model = llm::load_dynamic(
architecture,
&path,
model_architecture,
&model_path,
vocabulary_source,
Default::default(),
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
.unwrap_or_else(|err| {
panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
});

println!(
"Model fully loaded! Elapsed: {}ms",
Expand Down
16 changes: 9 additions & 7 deletions crates/llm/examples/vicuna-chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::{convert::Infallible, io::Write, path::PathBuf};

#[derive(Parser)]
struct Args {
architecture: String,
path: PathBuf,
model_architecture: llm::ModelArchitecture,
model_path: PathBuf,
#[arg(long, short = 'v')]
pub vocabulary_path: Option<PathBuf>,
#[arg(long, short = 'r')]
Expand All @@ -28,16 +28,18 @@ fn main() {
let args = Args::parse();

let vocabulary_source = args.to_vocabulary_source();
let architecture = args.architecture.parse().unwrap();
let path = args.path;
let model_architecture = args.model_architecture;
let model_path = args.model_path;
let model = llm::load_dynamic(
architecture,
&path,
model_architecture,
&model_path,
vocabulary_source,
Default::default(),
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
.unwrap_or_else(|err| {
panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
});

let mut session = model.start_session(Default::default());

Expand Down

0 comments on commit bf30f84

Please sign in to comment.