* Local LLM Backend
* Import more, qualify less
* Remove callback from Executor
* Don't publish load function
* Use latest `llm` with better model loading
* Implement naive `answer_prefix`
* Update with latest llm release
* Fix CI
Showing 9 changed files with 833 additions and 4 deletions.
@@ -0,0 +1,23 @@
[package]
name = "llm-chain-local"
version = "0.9.0"
edition = "2021"
description = "Use `llm-chain` with a local [`llm`](https://github.com/rustformers/llm) backend."
license = "MIT"
keywords = ["llm", "langchain", "ggml", "chain"]
categories = ["science"]
authors = ["Dan Forbes <dan@danforbes.dev>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
llm = "0.1.1"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }
@@ -0,0 +1,41 @@
use std::{env::args, error::Error, path::PathBuf};

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::{options::PerExecutor, Executor as LocalExecutor};

extern crate llm_chain_local;

/// This example demonstrates how to use the llm-chain-local crate to generate text using a model.
///
/// Usage: cargo run --release --package llm-chain-local --example simple model_type path/to/model
///
/// For example, if the model is a LLaMA-type model located at "/models/llama":
/// cargo run --release --package llm-chain-local --example simple llama /models/llama
///
/// An optional third argument can be used to customize the prompt passed to the model.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    let raw_args: Vec<String> = args().collect();
    let args = match &raw_args.len() {
        3 => (raw_args[1].as_str(), raw_args[2].as_str(), "Rust is a cool programming language because"),
        4 => (raw_args[1].as_str(), raw_args[2].as_str(), raw_args[3].as_str()),
        _ => panic!("Usage: cargo run --release --example simple <model type> <path to model> <optional prompt>")
    };

    let model_type = args.0;
    let model_path = args.1;
    let prompt = args.2;

    let exec_opts = PerExecutor {
        model_path: Some(PathBuf::from(model_path)),
        model_type: Some(String::from(model_type)),
    };

    let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
    let res = exec
        .execute(None, &Data::Text(String::from(prompt)))
        .await?;

    println!("{}", res);
    Ok(())
}
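The example above runs with default inference settings. Below is a minimal sketch (not part of this commit) of the same flow with explicit per-invocation options; the model path "/models/ggml-model.bin", the "llama" model type, and the chosen sampling values are placeholders for illustration only.

use std::{error::Error, path::PathBuf};

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::{
    options::{PerExecutor, PerInvocation},
    Executor as LocalExecutor,
};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    // Where the model lives and which architecture to load it as (placeholders).
    let exec_opts = PerExecutor {
        model_path: Some(PathBuf::from("/models/ggml-model.bin")),
        model_type: Some(String::from("llama")),
    };

    // Sampling settings; every field is optional and falls back to the
    // defaults hard-coded in the executor (e.g. temp 0.8, top_k 40).
    let inv_opts = PerInvocation {
        n_threads: Some(8),
        temp: Some(0.7),
        top_k: Some(40),
        top_p: Some(0.9),
        ..Default::default()
    };

    let exec = LocalExecutor::new_with_options(Some(exec_opts), Some(inv_opts.clone()))?;
    let res = exec
        .execute(Some(&inv_opts), &Data::Text("Rust is".to_string()))
        .await?;
    println!("{}", res);
    Ok(())
}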
@@ -0,0 +1,208 @@
use std::convert::Infallible;
use std::env::var;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::{
    load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
    TokenBias, TokenId, TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use thiserror::Error;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
    llm: Box<dyn Model>,
}

impl Executor {
    pub(crate) fn get_llm(&self) -> &dyn Model {
        self.llm.as_ref()
    }
}

#[derive(Debug, Error)]
pub enum Error {
    #[error("unable to tokenize prompt")]
    PromptTokensError(PromptTokensError),
    #[error("unable to create executor: {0}")]
    InnerError(#[from] Box<dyn std::error::Error>),
}

impl ExecutorError for Error {}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
    type PerInvocationOptions = PerInvocation;
    type PerExecutorOptions = PerExecutor;
    type Output = Output;
    type Error = Error;
    type Token = i32;
    type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
    type TextSplitter<'a> = LocalLlmTextSplitter<'a>;

    fn new_with_options(
        options: Option<Self::PerExecutorOptions>,
        invocation_options: Option<Self::PerInvocationOptions>,
    ) -> Result<Self, ExecutorCreationError> {
        let model_type = options
            .as_ref()
            .and_then(|x| x.model_type.clone())
            .or_else(|| var("LLM_MODEL_TYPE").ok())
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_type: provide the parameter or set the `LLM_MODEL_TYPE` environment variable".to_string(),
            ))?;
        let model_path = options
            .as_ref()
            .and_then(|x| x.model_path.clone())
            .or_else(|| var("LLM_MODEL_PATH").ok().map(PathBuf::from))
            .ok_or(ExecutorCreationError::FieldRequiredError(
                "model_path: provide the parameter or set the `LLM_MODEL_PATH` environment variable".to_string(),
            ))?;

        let model_arch = model_type
            .parse::<ModelArchitecture>()
            .map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;
        let params = invocation_options.unwrap_or_default().into();
        let llm: Box<dyn Model> = llm::load_dynamic(
            model_arch,
            Path::new(&model_path),
            params,
            load_progress_callback_stdout,
        )
        .map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;

        Ok(Executor { llm })
    }

    async fn execute(
        &self,
        options: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<Self::Output, Self::Error> {
        let parameters = match options {
            None => Default::default(),
            Some(opts) => InferenceParameters {
                n_threads: opts.n_threads.unwrap_or(4),
                n_batch: opts.n_batch.unwrap_or(8),
                top_k: opts.top_k.unwrap_or(40),
                top_p: opts.top_p.unwrap_or(0.95),
                temperature: opts.temp.unwrap_or(0.8),
                bias_tokens: {
                    match &opts.bias_tokens {
                        None => Default::default(),
                        Some(str) => TokenBias::from_str(str.as_str())
                            .map_err(|e| Error::InnerError(e.into()))?,
                    }
                },
                repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
                repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
            },
        };
        let session = &mut self.llm.start_session(Default::default());
        let mut output = String::new();
        session
            .infer::<Infallible>(
                self.llm.as_ref(),
                &mut rand::thread_rng(),
                &InferenceRequest {
                    prompt: prompt.to_text().as_str(),
                    parameters: Some(&parameters),
                    // playback_previous_tokens
                    // maximum_token_count
                    ..Default::default()
                },
                // OutputRequest
                &mut Default::default(),
                |t| {
                    output.push_str(t);

                    Ok(())
                },
            )
            .map_err(|e| Error::InnerError(Box::new(e)))?;

        Ok(output.into())
    }

    fn tokens_used(
        &self,
        options: Option<&Self::PerInvocationOptions>,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let tokenizer = self.get_tokenizer(options)?;
        let input = prompt.to_text();

        let tokens_used = tokenizer
            .tokenize_str(&input)
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .len() as i32;
        let max_tokens = self.max_tokens_allowed(options);
        Ok(TokenCount::new(max_tokens, tokens_used))
    }

    fn max_tokens_allowed(&self, _: Option<&Self::PerInvocationOptions>) -> i32 {
        self.llm.n_context_tokens().try_into().unwrap_or(2048)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
        Ok(LocalLlmTokenizer::new(self))
    }

    fn get_text_splitter(
        &self,
        _: Option<&Self::PerInvocationOptions>,
    ) -> Result<Self::TextSplitter<'_>, Self::Error> {
        Ok(LocalLlmTextSplitter::new(self))
    }
}

pub struct LocalLlmTokenizer<'a> {
    llm: &'a dyn Model,
}

impl<'a> LocalLlmTokenizer<'a> {
    pub fn new(executor: &'a Executor) -> Self {
        LocalLlmTokenizer {
            llm: executor.llm.as_ref(),
        }
    }
}

impl Tokenizer<llm::TokenId> for LocalLlmTokenizer<'_> {
    fn tokenize_str(&self, doc: &str) -> Result<Vec<llm::TokenId>, TokenizerError> {
        match &self.llm.vocabulary().tokenize(doc, false) {
            Ok(tokens) => Ok(tokens.into_iter().map(|t| t.1).collect()),
            Err(_) => Err(TokenizerError::TokenizationError),
        }
    }

    fn to_string(&self, tokens: Vec<TokenId>) -> Result<String, TokenizerError> {
        let mut res = String::new();
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        for token_id in tokens {
            // Buffer the token until it forms valid UTF-8, then append it to the result.
            if let Some(tokens) =
                token_utf8_buf.push(self.llm.vocabulary().token(token_id as usize))
            {
                res.push_str(&tokens)
            }
        }

        Ok(res)
    }
}
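As an alternative to passing `PerExecutor`, `new_with_options` falls back to the `LLM_MODEL_TYPE` and `LLM_MODEL_PATH` environment variables. A rough sketch (not part of this commit) of that path, also using the tokenizer exposed through `get_tokenizer` to count prompt tokens against the context window; the model path and prompt text are placeholders.

use std::error::Error;

use llm_chain::{tokens::Tokenizer, traits::Executor};
use llm_chain_local::Executor as LocalExecutor;

fn main() -> Result<(), Box<dyn Error>> {
    // With no PerExecutor options, the executor reads these variables instead.
    std::env::set_var("LLM_MODEL_TYPE", "llama");
    std::env::set_var("LLM_MODEL_PATH", "/models/ggml-model.bin"); // placeholder path

    let exec = LocalExecutor::new_with_options(None, None)?;

    // Tokenize a prompt and compare against the model's context size.
    let prompt = "Rust is a cool programming language because";
    if let Ok(tokenizer) = exec.get_tokenizer(None) {
        if let Ok(tokens) = tokenizer.tokenize_str(prompt) {
            println!(
                "prompt uses {} of {} context tokens",
                tokens.len(),
                exec.max_tokens_allowed(None)
            );
        }
    }
    Ok(())
}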
@@ -0,0 +1,7 @@
mod executor;
pub mod options;
mod output;
mod text_splitter;

pub use executor::Executor;
pub use text_splitter::LocalLlmTextSplitter;
@@ -0,0 +1,92 @@
use std::path::PathBuf;

use llm::{InferenceParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
    pub n_threads: Option<usize>,
    pub n_batch: Option<usize>,
    pub n_tok_predict: Option<usize>,
    pub top_k: Option<usize>,
    pub top_p: Option<f32>,
    pub temp: Option<f32>,
    /// A comma-separated list of token biases. The list should be in the format
    /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
    /// floating point number.
    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
    /// (start of document) and 2 (end of document) to -1.0, which effectively
    /// prevents the model from generating responses containing those token IDs.
    pub bias_tokens: Option<String>,
    pub repeat_penalty: Option<f32>,
    pub repeat_penalty_last_n: Option<usize>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
    fn into(self) -> ModelParameters {
        let inference_parameters = InferenceParameters {
            n_threads: self.n_threads.unwrap_or(4),
            n_batch: self.n_batch.unwrap_or(8),
            top_k: self.top_k.unwrap_or(40),
            top_p: self.top_p.unwrap_or(0.95),
            repeat_penalty: self.repeat_penalty.unwrap_or(1.3),
            repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
            temperature: self.temp.unwrap_or(0.8),
            bias_tokens: Default::default(),
        };

        ModelParameters {
            prefer_mmap: true,
            n_context_tokens: 2048,
            inference_parameters,
        }
    }
}

/// `PerExecutor` represents a collection of configuration parameters for the executor of the LLM.
/// It contains optional fields for the model path and model type.
///
/// # Examples
///
/// ```
/// use llm_chain_local::options::PerExecutor;
/// let executor = PerExecutor::new().with_model_path("path/to/model");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
    /// Optional path to the LLM.
    pub model_path: Option<PathBuf>,
    /// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
    pub model_type: Option<String>,
}

impl PerExecutor {
    /// Creates a new `PerExecutor` instance with default values.
    ///
    /// # Returns
    ///
    /// A `PerExecutor` instance with default values for the model path and model type.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model path for the current `PerExecutor` instance.
    ///
    /// # Arguments
    ///
    /// * `model_path` - The path to the LLM.
    ///
    /// # Returns
    ///
    /// A new `PerExecutor` instance with the updated model path.
    pub fn with_model_path(mut self, model_path: &str) -> Self {
        self.model_path = Some(PathBuf::from(model_path));
        self
    }
}

impl Options for PerExecutor {}
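To illustrate the `bias_tokens` format documented above, here is a minimal sketch (not part of this commit) that builds a `PerInvocation` suppressing the start-of-document and end-of-document tokens, using the "1=-1.0,2=-1.0" value from the doc comment's own example.

use llm_chain_local::options::PerInvocation;

fn main() {
    // "TID=BIAS" pairs, comma-separated; a bias of -1.0 effectively disables a token.
    let opts = PerInvocation {
        bias_tokens: Some("1=-1.0,2=-1.0".to_string()),
        repeat_penalty: Some(1.3),
        ..Default::default()
    };

    // These options can be supplied to Executor::new_with_options (where they
    // become the loaded model's parameters) or passed per call to execute.
    println!("{:?}", opts);
}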