refactor(examples): add vocabulary via clap
philpax committed May 29, 2023
1 parent faecd36 commit d795bc8
Showing 7 changed files with 130 additions and 71 deletions.
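For context: the refactored examples choose where the model's vocabulary (tokenizer) comes from. The selection type is not shown in this diff; inferred from its call sites below, it looks roughly like the following sketch, which may differ from the llm crate's actual definition.

use std::path::PathBuf;

// Sketch inferred from usage in the examples below, not quoted from the crate.
pub enum VocabularySource {
    /// Use the vocabulary embedded in the model file itself.
    Model,
    /// Load a HuggingFace tokenizer file from a local path.
    HuggingFaceTokenizerFile(PathBuf),
    /// Fetch the tokenizer from a named remote HuggingFace repository.
    HuggingFaceRemote(String),
}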
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

6 changes: 4 additions & 2 deletions Cargo.toml
@@ -20,12 +20,14 @@ bytemuck = "1.13.1"
 bytesize = "1.1"
 log = "0.4"
 rand = "0.8.5"
-thiserror = "1.0"
-anyhow = "1.0"
+
 rustyline = { version = "11.0.0", features = ["derive"] }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = { version = "1.0" }
 spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
+thiserror = "1.0"
+anyhow = "1.0"
+clap = { version = "4.1.8", features = ["derive"] }
 
 # Config for 'cargo dist'
 [workspace.metadata.dist]
2 changes: 1 addition & 1 deletion binaries/llm-cli/Cargo.toml
@@ -19,11 +19,11 @@ log = { workspace = true }
 rand = { workspace = true }
 rustyline = { workspace = true }
 spinoff = { workspace = true }
+clap = { workspace = true }
 
 bincode = "1.3.3"
 env_logger = "0.10.0"
 num_cpus = "1.15.0"
 
-clap = { version = "4.1.8", features = ["derive"] }
 color-eyre = { version = "0.6.2", default-features = false }
 zstd = { version = "0.12", default-features = false }
1 change: 1 addition & 0 deletions crates/llm/Cargo.toml
@@ -25,6 +25,7 @@ rand = { workspace = true }
 rustyline = { workspace = true }
 spinoff = { workspace = true }
 serde_json = { workspace = true }
+clap = { workspace = true }
 
 [features]
 default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
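The three manifest changes above follow Cargo's workspace dependency inheritance: the root Cargo.toml declares clap once (presumably under its [workspace.dependencies] table, which the hunk context does not show), and each member crate opts in, keeping version and features in sync. A minimal sketch of the pattern:

# Root Cargo.toml: declare the shared dependency once.
[workspace.dependencies]
clap = { version = "4.1.8", features = ["derive"] }

# A member crate's Cargo.toml: inherit version and features from the workspace.
[dependencies]
clap = { workspace = true }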
84 changes: 52 additions & 32 deletions crates/llm/examples/embeddings.rs
@@ -1,30 +1,52 @@
-use std::path::Path;
+use std::path::PathBuf;
 
-fn main() {
-    // Get arguments from command line
-    let raw_args: Vec<String> = std::env::args().skip(1).collect();
-    if raw_args.len() < 2 {
-        println!("Usage: cargo run --release --example embeddings <model_architecture> <model_path> [query] [comma-separated comparands] [overrides, json]");
-        std::process::exit(1);
-    }
+use clap::Parser;
+
+#[derive(Parser)]
+struct Args {
+    architecture: String,
+    path: PathBuf,
+    #[arg(long, short = 'v')]
+    pub vocabulary_path: Option<PathBuf>,
+    #[arg(long, short = 'r')]
+    pub vocabulary_repository: Option<String>,
+    #[arg(long, short = 'q')]
+    pub query: Option<String>,
+    #[arg(long, short = 'c')]
+    pub comparands: Vec<String>,
+}
+impl Args {
+    pub fn to_vocabulary_source(&self) -> llm::VocabularySource {
+        match (&self.vocabulary_path, &self.vocabulary_repository) {
+            (Some(_), Some(_)) => {
+                panic!("Cannot specify both --vocabulary-path and --vocabulary-repository");
+            }
+            (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
+            (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()),
+            (None, None) => llm::VocabularySource::Model,
+        }
+    }
+}
 
-    let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap();
-    let model_path = Path::new(&raw_args[1]);
-    let query = raw_args
-        .get(2)
-        .map(|s| s.as_str())
+fn main() {
+    let args = Args::parse();
+
+    let vocabulary_source = args.to_vocabulary_source();
+    let architecture = args.architecture.parse().unwrap();
+    let path = args.path;
+    let query = args
+        .query
+        .as_deref()
         .unwrap_or("My favourite animal is the dog");
-    let comparands = raw_args
-        .get(3)
-        .map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>())
-        .unwrap_or_else(|| {
-            vec![
-                "My favourite animal is the dog",
-                "I have just adopted a cute dog",
-                "My favourite animal is the cat",
-            ]
-        });
-    let overrides = raw_args.get(4).map(|s| serde_json::from_str(s).unwrap());
+    let comparands = if !args.comparands.is_empty() {
+        args.comparands
+    } else {
+        vec![
+            "My favourite animal is the dog".to_string(),
+            "I have just adopted a cute dog".to_string(),
+            "My favourite animal is the cat".to_string(),
+        ]
+    };
 
     // Load model
     let model_params = llm::ModelParameters {
@@ -33,25 +55,23 @@ fn main() {
         lora_adapters: None,
     };
     let model = llm::load_dynamic(
-        model_architecture,
-        model_path,
-        llm::VocabularySource::Model,
+        architecture,
+        &path,
+        vocabulary_source,
         model_params,
-        overrides,
+        None,
         llm::load_progress_callback_stdout,
     )
-    .unwrap_or_else(|err| {
-        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
-    });
+    .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
     let inference_parameters = llm::InferenceParameters::default();
 
     // Generate embeddings for query and comparands
     let query_embeddings = get_embeddings(model.as_ref(), &inference_parameters, query);
     let comparand_embeddings: Vec<(String, Vec<f32>)> = comparands
         .iter()
-        .map(|&text| {
+        .map(|text| {
             (
-                text.to_owned(),
+                text.clone(),
                 get_embeddings(model.as_ref(), &inference_parameters, text),
             )
         })
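With clap in place, each example documents its flags through the generated --help instead of a hand-rolled Usage string. A hypothetical invocation of the embeddings example (the model path and tokenizer repository are placeholders, not values from this commit):

cargo run --release --example embeddings -- \
    llama ./models/model.bin \
    -q "My favourite animal is the dog" \
    -c "I have just adopted a cute dog" \
    -c "My favourite animal is the cat" \
    -r example-org/example-tokenizer

Two interface changes ride along: comparands were previously a single comma-separated positional argument but are now passed as one -c flag per value, and the JSON overrides positional is gone entirely (load_dynamic now receives None).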
57 changes: 38 additions & 19 deletions crates/llm/examples/inference.rs
@@ -1,33 +1,52 @@
-use std::{convert::Infallible, io::Write, path::Path};
+use clap::Parser;
+use std::{convert::Infallible, io::Write, path::PathBuf};
 
-fn main() {
-    let raw_args: Vec<String> = std::env::args().skip(1).collect();
-    if raw_args.len() < 2 {
-        println!("Usage: cargo run --release --example inference <model_architecture> <model_path> [prompt] [overrides, json]");
-        std::process::exit(1);
-    }
+#[derive(Parser)]
+struct Args {
+    architecture: String,
+    path: PathBuf,
+    #[arg(long, short = 'p')]
+    prompt: Option<String>,
+    #[arg(long, short = 'v')]
+    vocabulary_path: Option<PathBuf>,
+    #[arg(long, short = 'r')]
+    vocabulary_repository: Option<String>,
+}
+impl Args {
+    pub fn to_vocabulary_source(&self) -> llm::VocabularySource {
+        match (&self.vocabulary_path, &self.vocabulary_repository) {
+            (Some(_), Some(_)) => {
+                panic!("Cannot specify both --vocabulary-path and --vocabulary-repository");
+            }
+            (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
+            (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()),
+            (None, None) => llm::VocabularySource::Model,
+        }
+    }
+}
 
-    let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap();
-    let model_path = Path::new(&raw_args[1]);
-    let prompt = raw_args
-        .get(2)
-        .map(|s| s.as_str())
+fn main() {
+    let args = Args::parse();
+
+    let vocabulary_source = args.to_vocabulary_source();
+    let architecture = args.architecture.parse().unwrap();
+    let path = args.path;
+    let prompt = args
+        .prompt
+        .as_deref()
         .unwrap_or("Rust is a cool programming language because");
-    let overrides = raw_args.get(3).map(|s| serde_json::from_str(s).unwrap());
 
     let now = std::time::Instant::now();
 
     let model = llm::load_dynamic(
-        model_architecture,
-        model_path,
-        llm::VocabularySource::Model,
+        architecture,
+        &path,
+        vocabulary_source,
         Default::default(),
-        overrides,
+        None,
         llm::load_progress_callback_stdout,
     )
-    .unwrap_or_else(|err| {
-        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
-    });
+    .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
 
     println!(
         "Model fully loaded! Elapsed: {}ms",
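to_vocabulary_source is duplicated in every example, and it enforces the path/repository exclusivity with a runtime panic!. clap can express the same constraint declaratively, so parsing fails with a proper usage error instead. A sketch of that alternative (not what this commit does):

use std::path::PathBuf;

use clap::Parser;

#[derive(Parser)]
struct Args {
    architecture: String,
    path: PathBuf,
    // clap rejects the flag combination at parse time, so the handwritten
    // panic! in to_vocabulary_source would become unreachable.
    #[arg(long, short = 'v', conflicts_with = "vocabulary_repository")]
    vocabulary_path: Option<PathBuf>,
    #[arg(long, short = 'r')]
    vocabulary_repository: Option<String>,
}

fn main() {
    let _args = Args::parse();
}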
50 changes: 33 additions & 17 deletions crates/llm/examples/vicuna-chat.rs
@@ -1,28 +1,44 @@
+use clap::Parser;
 use rustyline::error::ReadlineError;
-use std::{convert::Infallible, io::Write, path::Path};
+use std::{convert::Infallible, io::Write, path::PathBuf};
 
-fn main() {
-    let raw_args: Vec<String> = std::env::args().skip(1).collect();
-    if raw_args.len() < 2 {
-        println!("Usage: cargo run --release --example vicuna-chat <model_architecture> <model_path> [overrides, json]");
-        std::process::exit(1);
-    }
+#[derive(Parser)]
+struct Args {
+    architecture: String,
+    path: PathBuf,
+    #[arg(long, short = 'v')]
+    pub vocabulary_path: Option<PathBuf>,
+    #[arg(long, short = 'r')]
+    pub vocabulary_repository: Option<String>,
+}
+impl Args {
+    pub fn to_vocabulary_source(&self) -> llm::VocabularySource {
+        match (&self.vocabulary_path, &self.vocabulary_repository) {
+            (Some(_), Some(_)) => {
+                panic!("Cannot specify both --vocabulary-path and --vocabulary-repository");
+            }
+            (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
+            (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()),
+            (None, None) => llm::VocabularySource::Model,
+        }
+    }
+}
 
-    let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap();
-    let model_path = Path::new(&raw_args[1]);
-    let overrides = raw_args.get(2).map(|s| serde_json::from_str(s).unwrap());
+fn main() {
+    let args = Args::parse();
+
+    let vocabulary_source = args.to_vocabulary_source();
+    let architecture = args.architecture.parse().unwrap();
+    let path = args.path;
     let model = llm::load_dynamic(
-        model_architecture,
-        model_path,
-        llm::VocabularySource::Model,
+        architecture,
+        &path,
+        vocabulary_source,
         Default::default(),
-        overrides,
+        None,
         llm::load_progress_callback_stdout,
     )
-    .unwrap_or_else(|err| {
-        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
-    });
+    .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}"));
 
     let mut session = model.start_session(Default::default());
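As the three examples show, the Args struct and to_vocabulary_source are now copy-pasted per example. A possible follow-up, sketched here as a suggestion rather than anything in this commit, is to share the model-selection flags through clap's #[command(flatten)]:

use std::path::PathBuf;

use clap::Parser;

// Hypothetical shared definition (e.g. an examples/common module).
#[derive(clap::Args)]
struct ModelArgs {
    architecture: String,
    path: PathBuf,
    #[arg(long, short = 'v')]
    vocabulary_path: Option<PathBuf>,
    #[arg(long, short = 'r')]
    vocabulary_repository: Option<String>,
}

// Each example then adds only its own flags.
#[derive(Parser)]
struct Args {
    #[command(flatten)]
    model: ModelArgs,
    #[arg(long, short = 'p')]
    prompt: Option<String>,
}

fn main() {
    let _args = Args::parse();
}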
