Local LLM Backend #116
Merged
+833 −4
Changes from 9 commits

Commits (10):
7601079 Local LLM Backend (danforbes)
837f639 Import more, qualify less (danforbes)
8e3e5e8 Remove callback from Executor (danforbes)
fdac5d8 Don't publish load function (danforbes)
d98e8b5 Use latest `llm` with better model loading (danforbes)
fd4d436 Merge remote-tracking branch 'origin/main' into dfo/feat/local-llm (danforbes)
b45a2c3 Implement naive `answer_prefix` (danforbes)
bcdc265 Merge remote-tracking branch 'origin/main' into dfo/feat/local-llm (danforbes)
ccf9051 Update with latest llm release (danforbes)
504e092 Fix CI (danforbes)
llm-chain-local/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "llm-chain-local"
version = "0.9.0"
edition = "2021"
description = "Use `llm-chain` with a local [`llm`](https://github.com/rustformers/llm) backend."
license = "MIT"
keywords = ["llm", "langchain", "ggml", "chain"]
categories = ["science"]
authors = ["Dan Forbes <dan@danforbes.dev>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
llm = "0.1.1"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }
llm-chain-local/examples/simple.rs
@@ -0,0 +1,39 @@
use std::{env::args, error::Error, path::PathBuf};

use llm_chain::{traits::Executor, prompt::Data};
use llm_chain_local::{Executor as LocalExecutor, options::PerExecutor};

extern crate llm_chain_local;

/// This example demonstrates how to use the llm-chain-local crate to generate text using a model.
///
/// Usage: cargo run --release --package llm-chain-local --example simple model_type path/to/model
///
/// For example, if the model is a LLaMA-type model located at "/models/llama":
/// cargo run --release --package llm-chain-local --example simple llama /models/llama
///
/// An optional third argument can be used to customize the prompt passed to the model.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    let raw_args: Vec<String> = args().collect();
    let args = match &raw_args.len() {
        3 => (raw_args[1].as_str(), raw_args[2].as_str(), "Rust is a cool programming language because"),
        4 => (raw_args[1].as_str(), raw_args[2].as_str(), raw_args[3].as_str()),
        _ => panic!("Usage: cargo run --release --package llm-chain-local --example simple <model type> <path to model> <optional prompt>")
    };

    let model_type = args.0;
    let model_path = args.1;
    let prompt = args.2;

    let exec_opts = PerExecutor {
        model_path: Some(PathBuf::from(model_path)),
        model_type: Some(String::from(model_type)),
    };

    let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
    let res = exec.execute(None, &Data::Text(String::from(prompt))).await?;

    println!("{}", res);
    Ok(())
}
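As an alternative to building a `PerExecutor`, the executor can also resolve its model settings from the `LLM_MODEL_TYPE` and `LLM_MODEL_PATH` environment variables (see `new_with_options` in `executor.rs` below). A minimal sketch of that variant, assuming both variables are already exported; the prompt text is just an illustrative choice:

```rust
use std::error::Error;

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::Executor as LocalExecutor;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    // Assumes LLM_MODEL_TYPE (e.g. "llama") and LLM_MODEL_PATH are set in the
    // environment; both the executor and invocation options are left as None.
    let exec = LocalExecutor::new_with_options(None, None)?;
    let res = exec
        .execute(None, &Data::Text("Rust is a cool programming language because".to_string()))
        .await?;
    println!("{}", res);
    Ok(())
}
```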
llm-chain-local/src/executor.rs
@@ -0,0 +1,208 @@
use std::convert::Infallible;
use std::env::var;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::{
    load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
    TokenBias, TokenId, TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use thiserror::Error;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
    llm: Box<dyn Model>,
}

impl Executor {
    pub(crate) fn get_llm(&self) -> &dyn Model {
        self.llm.as_ref()
    }
}

#[derive(Debug, Error)]
pub enum Error {
    #[error("unable to tokenize prompt")]
    PromptTokensError(PromptTokensError),
    #[error("unable to create executor: {0}")]
    InnerError(#[from] Box<dyn std::error::Error>),
}

impl ExecutorError for Error {}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
    type PerInvocationOptions = PerInvocation;
    type PerExecutorOptions = PerExecutor;
    type Output = Output;
    type Error = Error;
    type Token = i32;
    type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
    type TextSplitter<'a> = LocalLlmTextSplitter<'a>;

    fn new_with_options(
        options: Option<Self::PerExecutorOptions>,
        invocation_options: Option<Self::PerInvocationOptions>,
    ) -> Result<Self, ExecutorCreationError> {
        let model_type = options
            .as_ref()
            .and_then(|x| x.model_type.clone())
            .or_else(|| var("LLM_MODEL_TYPE").ok())
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_type: provide the parameter or set the `LLM_MODEL_TYPE` environment variable".to_string(),
            ))?;
        let model_path = options
            .as_ref()
            .and_then(|x| x.model_path.clone())
            .or_else(|| var("LLM_MODEL_PATH").ok().map(PathBuf::from))
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_path: provide the parameter or set the `LLM_MODEL_PATH` environment variable".to_string(),
            ))?;

        let model_arch = model_type
            .parse::<ModelArchitecture>()
            .map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;
        let params = invocation_options.unwrap_or_default().into();
        let llm: Box<dyn Model> = llm::load_dynamic(
            model_arch,
            Path::new(&model_path),
            params,
            load_progress_callback_stdout,
        )
        .map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;

        Ok(Executor { llm })
    }

    async fn execute(
        &self,
        options: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<Self::Output, Self::Error> {
        let parameters = match options {
            None => Default::default(),
            Some(opts) => InferenceParameters {
                n_threads: opts.n_threads.unwrap_or(4),
                n_batch: opts.n_batch.unwrap_or(8),
                top_k: opts.top_k.unwrap_or(40),
                top_p: opts.top_p.unwrap_or(0.95),
                temperature: opts.temp.unwrap_or(0.8),
                bias_tokens: {
                    match &opts.bias_tokens {
                        None => Default::default(),
                        Some(str) => TokenBias::from_str(str.as_str())
                            .map_err(|e| Error::InnerError(e.into()))?,
                    }
                },
                repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
                repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
            },
        };
        let session = &mut self.llm.start_session(Default::default());
        let mut output = String::new();
        session
            .infer::<Infallible>(
                self.llm.as_ref(),
                &mut rand::thread_rng(),
                &InferenceRequest {
                    prompt: prompt.to_text().as_str(),
                    parameters: Some(&parameters),
                    // playback_previous_tokens
                    // maximum_token_count
                    ..Default::default()
                },
                // OutputRequest
                &mut Default::default(),
                |t| {
                    output.push_str(t);

                    Ok(())
                },
            )
            .map_err(|e| Error::InnerError(Box::new(e)))?;

        Ok(output.into())
    }

    fn tokens_used(
        &self,
        options: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let tokenizer = self.get_tokenizer(options)?;
        let input = prompt.to_text();

        let tokens_used = tokenizer
            .tokenize_str(&input)
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .len() as i32;
        let max_tokens = self.max_tokens_allowed(options);
        Ok(TokenCount::new(max_tokens, tokens_used))
    }

    fn max_tokens_allowed(&self, _: Option<&Self::PerInvocationOptions>) -> i32 {
        self.llm.n_context_tokens().try_into().unwrap_or(2048)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
        Ok(LocalLlmTokenizer::new(self))
    }

    fn get_text_splitter(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::TextSplitter<'_>, Self::Error> {
        Ok(LocalLlmTextSplitter::new(self))
    }
}

pub struct LocalLlmTokenizer<'a> {
    llm: &'a dyn Model,
}

impl<'a> LocalLlmTokenizer<'a> {
    pub fn new(executor: &'a Executor) -> Self {
        LocalLlmTokenizer {
            llm: executor.llm.as_ref(),
        }
    }
}

impl Tokenizer<llm::TokenId> for LocalLlmTokenizer<'_> {
    fn tokenize_str(&self, doc: &str) -> Result<Vec<llm::TokenId>, TokenizerError> {
        match &self.llm.vocabulary().tokenize(doc, false) {
            Ok(tokens) => Ok(tokens.into_iter().map(|t| t.1).collect()),
            Err(_) => Err(TokenizerError::TokenizationError),
        }
    }

    fn to_string(&self, tokens: Vec<TokenId>) -> Result<String, TokenizerError> {
        let mut res = String::new();
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        for token_id in tokens {
            // Buffer the token until it's valid UTF-8, then append it to the result.
            if let Some(tokens) =
                token_utf8_buf.push(self.llm.vocabulary().token(token_id as usize))
            {
                res.push_str(&tokens)
            }
        }

        Ok(res)
    }
}
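The defaults baked into `execute` above only apply when no per-invocation options are given. Here is a minimal sketch of overriding a couple of them with a `PerInvocation` value; the model path, model type string, field values, and prompt are arbitrary assumptions, not recommendations:

```rust
use std::path::PathBuf;

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::{options::{PerExecutor, PerInvocation}, Executor as LocalExecutor};

async fn run() -> Result<(), Box<dyn std::error::Error>> {
    let exec_opts = PerExecutor {
        model_path: Some(PathBuf::from("/models/llama")),
        model_type: Some("llama".to_string()),
    };

    // Only the fields set here override the defaults in `execute`
    // (4 threads, top_k 40, top_p 0.95, temperature 0.8,
    // repeat penalty 1.3 over the last 512 tokens).
    let inv_opts = PerInvocation {
        n_threads: Some(8),
        temp: Some(0.2),
        ..Default::default()
    };

    let exec = LocalExecutor::new_with_options(Some(exec_opts), Some(inv_opts.clone()))?;
    let res = exec
        .execute(Some(&inv_opts), &Data::Text("Rust is a cool programming language because".to_string()))
        .await?;
    println!("{}", res);
    Ok(())
}
```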
llm-chain-local/src/lib.rs
@@ -0,0 +1,7 @@
mod executor;
pub mod options;
mod output;
mod text_splitter;

pub use executor::Executor;
pub use text_splitter::LocalLlmTextSplitter;
llm-chain-local/src/options.rs
@@ -0,0 +1,92 @@
use std::path::PathBuf;

use llm::{InferenceParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
    pub n_threads: Option<usize>,
    pub n_batch: Option<usize>,
    pub n_tok_predict: Option<usize>,
    pub top_k: Option<usize>,
    pub top_p: Option<f32>,
    pub temp: Option<f32>,
    /// A comma-separated list of token biases. The list should be in the format
    /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
    /// floating point number.
    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
    /// (start of document) and 2 (end of document) to -1.0, which effectively
    /// prevents the model from generating responses containing those token IDs.
    pub bias_tokens: Option<String>,
    pub repeat_penalty: Option<f32>,
    pub repeat_penalty_last_n: Option<usize>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
    fn into(self) -> ModelParameters {
        let inference_parameters = InferenceParameters {
            n_threads: self.n_threads.unwrap_or(4),
            n_batch: self.n_batch.unwrap_or(8),
            top_k: self.top_k.unwrap_or(40),
            top_p: self.top_p.unwrap_or(0.95),
            repeat_penalty: self.repeat_penalty.unwrap_or(1.3),
            repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
            temperature: self.temp.unwrap_or(0.8),
            bias_tokens: Default::default(),
        };

        ModelParameters {
            prefer_mmap: true,
            n_context_tokens: 2048,
            inference_parameters,
        }
    }
}

/// `PerExecutor` represents a collection of configuration parameters for the executor of the LLM.
/// It contains optional fields for the model path and model type.
///
/// # Examples
///
/// ```
/// use llm_chain_local::options::PerExecutor;
/// let executor = PerExecutor::new().with_model_path("path/to/model");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
    /// Optional path to the LLM.
    pub model_path: Option<PathBuf>,
    /// Optional type (e.g. LLaMA, GPT-NeoX) of the LLM.
    pub model_type: Option<String>,
}

impl PerExecutor {
    /// Creates a new `PerExecutor` instance with default values.
    ///
    /// # Returns
    ///
    /// A `PerExecutor` instance with default values for the model path and model type.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model path for the current `PerExecutor` instance.
    ///
    /// # Arguments
    ///
    /// * `model_path` - The path to the LLM.
    ///
    /// # Returns
    ///
    /// A new `PerExecutor` instance with the updated model path.
    pub fn with_model_path(mut self, model_path: &str) -> Self {
        self.model_path = Some(PathBuf::from(model_path));
        self
    }
}

impl Options for PerExecutor {}
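At load time, `PerInvocation` is converted into the `llm` crate's `ModelParameters` through the `Into` impl above. A small sketch of that conversion; the override values are illustrative assumptions:

```rust
use llm::ModelParameters;
use llm_chain_local::options::PerInvocation;

fn model_params_example() -> ModelParameters {
    // Only n_threads and temp are overridden; every other field falls back to
    // the defaults hard-coded in the Into impl (n_batch 8, top_k 40, top_p 0.95,
    // repeat penalty 1.3 over the last 512 tokens), and the resulting
    // ModelParameters always prefers mmap with a 2048-token context.
    let opts = PerInvocation {
        n_threads: Some(2),
        temp: Some(0.7),
        ..Default::default()
    };
    opts.into()
}
```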
Conversations
Any thoughts on using an enum with a fallback here, so it's clear which models are supported?
I mean something like:
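(For illustration only, a hypothetical shape such an enum could take; the variant names and architecture strings below are assumptions, not the reviewer's actual snippet.)

```rust
use serde::{Deserialize, Serialize};

/// Hypothetical sketch: an enum of known model types with a catch-all fallback,
/// so the supported architectures are discoverable while unknown names still work.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
    Llama,
    Gpt2,
    GptJ,
    GptNeoX,
    Bloom,
    /// Fallback for any architecture name not covered above.
    Other(String),
}

impl ModelType {
    /// Map to the architecture string that would be parsed into `ModelArchitecture`.
    fn as_arch_str(&self) -> &str {
        match self {
            ModelType::Llama => "llama",
            ModelType::Gpt2 => "gpt2",
            ModelType::GptJ => "gptj",
            ModelType::GptNeoX => "gptneox",
            ModelType::Bloom => "bloom",
            ModelType::Other(s) => s.as_str(),
        }
    }
}
```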
The model architecture enum isn't currently serializable (although I don't see why it can't be), so I'm leaving this a string for now.
Ah alright! I see!
Hopefully this will be changed soon: rustformers/llm#200