Local LLM Backend #116
@@ -0,0 +1,23 @@
[package]
name = "llm-chain-local"
version = "0.9.0"
edition = "2021"
description = "Use `llm-chain` with a local [`llm`](https://github.com/rustformers/llm) backend."
license = "MIT"
keywords = ["llm", "langchain", "ggml", "chain"]
categories = ["science"]
authors = ["Dan Forbes <dan@danforbes.dev>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
llm = { git = "https://github.com/danforbes/llama-rs", branch = "dfo/chore/llm-chain" }
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }
@@ -0,0 +1,39 @@
use std::{env::args, error::Error};

use llm_chain::{traits::Executor, prompt::Data};
use llm_chain_local::{Executor as LocalExecutor, options::PerExecutor};

extern crate llm_chain_local;

/// This example demonstrates how to use the llm-chain-local crate to generate text using a model.
///
/// Usage: cargo run --release --package llm-chain-local --example simple model_type path/to/model
///
/// For example, if the model is a LLaMA-type model located at "/models/llama":
/// cargo run --release --package llm-chain-local --example simple llama /models/llama
///
/// An optional third argument can be used to customize the prompt passed to the model.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    let raw_args: Vec<String> = args().collect();
    let args = match &raw_args.len() {
        3 => (raw_args[1].as_str(), raw_args[2].as_str(), "Rust is a cool programming language because"),
        4 => (raw_args[1].as_str(), raw_args[2].as_str(), raw_args[3].as_str()),
        _ => panic!("Usage: cargo run --release --example simple <model type> <path to model> <optional prompt>")
    };

    let model_type = args.0;
    let model_path = args.1;
    let prompt = args.2;

    let exec_opts = PerExecutor {
        model_path: Some(String::from(model_path)),
        model_type: Some(String::from(model_type)),
    };

    let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
    let res = exec.execute(None, &Data::Text(String::from(prompt))).await?;

    println!("{}", res);
    Ok(())
}
@@ -0,0 +1,191 @@
use std::convert::Infallible;
use std::env::var;
use std::path::Path;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::models::{Bloom, Gpt2, GptJ, Llama, NeoX};
use llm::{
    load_progress_callback_stdout, KnownModel, LoadError, Model, ModelParameters, TokenId,
    TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use thiserror::Error;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
    llm: Box<dyn Model>,
}

impl Executor {
    pub(crate) fn get_llm(&self) -> &dyn Model {
        self.llm.as_ref()
    }
}

#[derive(Debug, Error)]
pub enum Error {
    #[error("unable to tokenize prompt")]
    PromptTokensError(PromptTokensError),
    #[error("unable to create executor: {0}")]
    InnerError(#[from] Box<dyn std::error::Error>),
}

impl ExecutorError for Error {}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
    type PerInvocationOptions = PerInvocation;
    type PerExecutorOptions = PerExecutor;
    type Output = Output;
    type Error = Error;
    type Token = i32;
    type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
    type TextSplitter<'a> = LocalLlmTextSplitter<'a>;

    fn new_with_options(
        options: Option<Self::PerExecutorOptions>,
        invocation_options: Option<Self::PerInvocationOptions>,
    ) -> Result<Self, ExecutorCreationError> {
        let model_type = options
            .as_ref()
            .and_then(|x| x.model_type.clone())
            .or_else(|| var("LLM_MODEL_TYPE").ok())
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_type, provide the parameter or set the `LLM_MODEL_TYPE` environment variable".to_string(),
            ))?;
        let model_path = options
            .as_ref()
            .and_then(|x| x.model_path.clone())
            .or_else(|| var("LLM_MODEL_PATH").ok())
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_path, provide the parameter or set the `LLM_MODEL_PATH` environment variable".to_string(),
            ))?;

        let params = invocation_options.unwrap_or_default().into();
        let llm: Box<dyn Model> = match model_type.as_str() {
            "bloom" => load::<Bloom>(&model_path, params),
            "gpt2" => load::<Gpt2>(&model_path, params),
            "gptj" => load::<GptJ>(&model_path, params),
            "llama" => load::<Llama>(&model_path, params),
            "neox" => load::<NeoX>(&model_path, params),
            m => Err(LoadError::InvariantBroken {
                path: None,
                invariant: format!("Unsupported model type {m}"),
            }),
        }
        .map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;

        Ok(Executor { llm })
    }

    async fn execute(
        &self,
        // TODO: call infer_with_params if this is present
        _: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<Self::Output, Self::Error> {
        let session = &mut self.llm.start_session(Default::default());
        let mut output = String::new();
        session
            .infer::<Infallible>(
                self.llm.as_ref(),
                prompt.to_text().as_str(),
                // EvaluateOutputRequest
                &mut Default::default(),
                &mut rand::thread_rng(),
                |t| {
                    output.push_str(t);

                    Ok(())
                },
            )
            .map_err(|e| Error::InnerError(Box::new(e)))?;

        Ok(output.into())
    }

    fn tokens_used(
        &self,
        options: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let tokenizer = self.get_tokenizer(options)?;
        let input = prompt.to_text();

        let tokens_used = tokenizer
            .tokenize_str(&input)
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .len() as i32;
        let max_tokens = self.max_tokens_allowed(options);
        Ok(TokenCount::new(max_tokens, tokens_used))
    }

    fn max_tokens_allowed(&self, _: Option<&Self::PerInvocationOptions>) -> i32 {
        self.llm.n_context_tokens().try_into().unwrap_or(2048)
    }

    fn get_tokenizer(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
        Ok(LocalLlmTokenizer::new(self))
    }

    fn get_text_splitter(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::TextSplitter<'_>, Self::Error> {
        Ok(LocalLlmTextSplitter::new(self))
    }
}

pub struct LocalLlmTokenizer<'a> {
    llm: &'a dyn Model,
}

impl<'a> LocalLlmTokenizer<'a> {
    pub fn new(executor: &'a Executor) -> Self {
        LocalLlmTokenizer {
            llm: executor.llm.as_ref(),
        }
    }
}

impl Tokenizer<llm::TokenId> for LocalLlmTokenizer<'_> {
    fn tokenize_str(&self, doc: &str) -> Result<Vec<llm::TokenId>, TokenizerError> {
        match &self.llm.vocabulary().tokenize(doc, false) {
            Ok(tokens) => Ok(tokens.into_iter().map(|t| t.1).collect()),
            Err(_) => Err(TokenizerError::TokenizationError),
        }
    }

    fn to_string(&self, tokens: Vec<TokenId>) -> Result<String, TokenizerError> {
        let mut res = String::new();
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        for token_id in tokens {
            // Buffer the token until it forms valid UTF-8, then append it to the result.
            if let Some(tokens) =
                token_utf8_buf.push(self.llm.vocabulary().token(token_id as usize))
            {
                res.push_str(&tokens)
            }
        }

        Ok(res)
    }
}

fn load<M: KnownModel + 'static>(
    model_path: &str,
    params: ModelParameters,
) -> Result<Box<dyn Model>, LoadError> {
    let model = llm::load::<M>(Path::new(model_path), params, load_progress_callback_stdout)?;

    Ok(Box::new(model))
}
@@ -0,0 +1,7 @@
mod executor;
pub mod options;
mod output;
mod text_splitter;

pub use executor::Executor;
pub use text_splitter::LocalLlmTextSplitter;
@@ -0,0 +1,87 @@
use llm::{InferenceParameters, InferenceWithPromptParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
    pub n_threads: Option<usize>,
    pub n_tok_predict: Option<usize>,
    pub top_k: Option<usize>,
    pub top_p: Option<f32>,
    pub temp: Option<f32>,
    pub repeat_penalty: Option<f32>,
    pub stop_sequence: Option<String>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
    fn into(self) -> ModelParameters {
        let inference_params = InferenceParameters {
            n_threads: self.n_threads.unwrap_or(4),
            n_batch: 8,
            top_k: self.top_k.unwrap_or(40),
            top_p: self.top_p.unwrap_or(0.95),
            repeat_penalty: self.repeat_penalty.unwrap_or(1.3),
            temperature: self.temp.unwrap_or(0.8),
            bias_tokens: Default::default(),
        };

        let prompt_params = InferenceWithPromptParameters {
            play_back_previous_tokens: false,
            maximum_token_count: None,
        };

        ModelParameters {
            prefer_mmap: true,
            n_context_tokens: 2048,
            inference_params,
            inference_prompt_params: prompt_params,
        }
    }
}

/// `PerExecutor` represents a collection of configuration parameters for the executor of the LLM.
/// It contains optional fields for the model path and model type.
///
/// # Examples
///
/// ```
/// use llm_chain_local::options::PerExecutor;
/// let executor = PerExecutor::new().with_model_path("path/to/model");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
    /// Optional path to the LLM.
    pub model_path: Option<String>,
    /// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
    pub model_type: Option<String>,
}

impl PerExecutor {
    /// Creates a new `PerExecutor` instance with default values.
    ///
    /// # Returns
    ///
    /// A `PerExecutor` instance with default values for the model path and model type.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model path for the current `PerExecutor` instance.
    ///
    /// # Arguments
    ///
    /// * `model_path` - The path to the LLM.
    ///
    /// # Returns
    ///
    /// A new `PerExecutor` instance with the updated model path.
    pub fn with_model_path(mut self, model_path: &str) -> Self {
        self.model_path = Some(model_path.to_string());
        self
    }
}

impl Options for PerExecutor {}

Review thread on the `model_type` field:

Comment: Any thoughts on having an enum with a fallback, to make it clear which models are supported? I mean something like:

enum ModelType {
    LLama,
    GPTNeoX,
    Custom(String)
}

Reply: The model architecture enum isn't currently serializable (although I don't see why it can't be), so I'm leaving this as a string for now.

Reply: Ah, alright! I see!

Reply: Hopefully this will be changed soon: rustformers/llm#200
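
For illustration only, a rough sketch of the enum-with-fallback idea from the thread above. It is hypothetical and not part of this PR: the variant names are taken from the `match` in `new_with_options`, and the `From<&str>` impl is just one way the fallback could work.

// Hypothetical sketch of the suggested enum with a catch-all fallback; not part of this diff.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ModelType {
    Bloom,
    Gpt2,
    GptJ,
    Llama,
    NeoX,
    /// Fallback for architectures this crate does not know about yet.
    Custom(String),
}

impl From<&str> for ModelType {
    fn from(s: &str) -> Self {
        match s {
            "bloom" => ModelType::Bloom,
            "gpt2" => ModelType::Gpt2,
            "gptj" => ModelType::GptJ,
            "llama" => ModelType::Llama,
            "neox" => ModelType::NeoX,
            // Unknown strings are preserved rather than rejected.
            other => ModelType::Custom(other.to_string()),
        }
    }
}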
Review thread on the `model_path` field:

Comment: Could this be a `Path` instead, perhaps?

Reply: Without trying too hard, a first attempt fails because `Path` cannot be `Sized`, which appears to be required for `Option`.

Reply: Ah yes. How about `PathBuf`? That should be `Sized`.
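
Following up on that thread, a rough sketch of what a `PathBuf`-based `model_path` and builder might look like. This is hypothetical and not part of this diff; the PR as written keeps `Option<String>`.

// Hypothetical revision of `PerExecutor` using `PathBuf`; not part of this diff.
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
    /// Optional path to the LLM. `PathBuf` is owned and `Sized`, so it works
    /// inside `Option` where a bare `Path` would not.
    pub model_path: Option<PathBuf>,
    /// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
    pub model_type: Option<String>,
}

impl PerExecutor {
    /// Sets the model path, accepting anything that converts to a path
    /// (`&str`, `String`, `PathBuf`, ...).
    pub fn with_model_path(mut self, model_path: impl AsRef<Path>) -> Self {
        self.model_path = Some(model_path.as_ref().to_path_buf());
        self
    }
}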