Local LLM Backend (#116)
* Local LLM Backend

* Import more, qualify less

* Remove callback from Executor

* Don't publish load function

* Use latest `llm` with better model loading

* Implement naive `answer_prefix`

* Update with latest llm release

* Fix CI
danforbes authored May 9, 2023
1 parent 8856be8 commit cb8c4f3
Showing 9 changed files with 833 additions and 4 deletions.
345 changes: 341 additions & 4 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -4,6 +4,7 @@ members = [
"llm-chain-openai",
"llm-chain-llama",
"llm-chain-llama/sys",
"llm-chain-local",
"llm-chain-qdrant",
]

23 changes: 23 additions & 0 deletions llm-chain-local/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "llm-chain-local"
version = "0.9.0"
edition = "2021"
description = "Use `llm-chain` with a local [`llm`](https://github.com/rustformers/llm) backend."
license = "MIT"
keywords = ["llm", "langchain", "ggml", "chain"]
categories = ["science"]
authors = ["Dan Forbes <dan@danforbes.dev>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
llm = "0.1.1"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }
41 changes: 41 additions & 0 deletions llm-chain-local/examples/simple.rs
@@ -0,0 +1,41 @@
use std::{env::args, error::Error, path::PathBuf};

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::{options::PerExecutor, Executor as LocalExecutor};

extern crate llm_chain_local;

/// This example demonstrates how to use the llm-chain-local crate to generate text using a model.
///
/// Usage: cargo run --release --package llm-chain-local --example simple model_type path/to/model
///
/// For example, if the model is a LLaMA-type model located at "/models/llama"
/// cargo run --release --package llm-chain-local --example simple llama /models/llama
///
/// An optional third argument can be used to customize the prompt passed to the model.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
let raw_args: Vec<String> = args().collect();
let args = match &raw_args.len() {
3 => (raw_args[1].as_str(), raw_args[2].as_str(), "Rust is a cool programming language because"),
4 => (raw_args[1].as_str(), raw_args[2].as_str(), raw_args[3].as_str()),
_ => panic!("Usage: cargo run --release --package llm-chain-local --example simple <model type> <path to model> [prompt]")
};

let model_type = args.0;
let model_path = args.1;
let prompt = args.2;

let exec_opts = PerExecutor {
model_path: Some(PathBuf::from(model_path)),
model_type: Some(String::from(model_type)),
};

let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
let res = exec
.execute(None, &Data::Text(String::from(prompt)))
.await?;

println!("{}", res);
Ok(())
}
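The optional third argument overrides the built-in prompt. A usage sketch, assuming a LLaMA-type model at /models/llama as in the doc comment above (the prompt text is illustrative):

cargo run --release --package llm-chain-local --example simple llama /models/llama "Tell me about the Rust programming language"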
208 changes: 208 additions & 0 deletions llm-chain-local/src/executor.rs
@@ -0,0 +1,208 @@
use std::convert::Infallible;
use std::env::var;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::{
load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
TokenBias, TokenId, TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use thiserror::Error;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
llm: Box<dyn Model>,
}

impl Executor {
pub(crate) fn get_llm(&self) -> &dyn Model {
self.llm.as_ref()
}
}

#[derive(Debug, Error)]
pub enum Error {
#[error("unable to tokenize prompt")]
PromptTokensError(PromptTokensError),
#[error("unable to create executor: {0}")]
InnerError(#[from] Box<dyn std::error::Error>),
}

impl ExecutorError for Error {}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
type PerInvocationOptions = PerInvocation;
type PerExecutorOptions = PerExecutor;
type Output = Output;
type Error = Error;
type Token = i32;
type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
type TextSplitter<'a> = LocalLlmTextSplitter<'a>;

fn new_with_options(
options: Option<Self::PerExecutorOptions>,
invocation_options: Option<Self::PerInvocationOptions>,
) -> Result<Self, ExecutorCreationError> {
let model_type = options
.as_ref()
.and_then(|x| x.model_type.clone())
.or_else(|| var("LLM_MODEL_TYPE").ok())
.ok_or(ExecutorCreationError::FieldRequiredError(
"model_type, ensure to provide the parameter or set `LLM_MODEL_TYPE` environment variable ".to_string(),
))?;
let model_path = options
.as_ref()
.and_then(|x| x.model_path.clone())
.or_else(|| var("LLM_MODEL_PATH").ok().map(|s| PathBuf::from(s)))
.ok_or(ExecutorCreationError::FieldRequiredError(
"model_path, ensure to provide the parameter or set `LLM_MODEL_PATH` environment variable ".to_string(),
))?;

let model_arch = model_type
.parse::<ModelArchitecture>()
.map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;
let params = invocation_options.unwrap_or_default().into();
let llm: Box<dyn Model> = llm::load_dynamic(
model_arch,
Path::new(&model_path),
params,
load_progress_callback_stdout,
)
.map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;

Ok(Executor { llm })
}

async fn execute(
&self,
options: Option<&Self::PerInvocationOptions>,
prompt: &Prompt,
) -> Result<Self::Output, Self::Error> {
let parameters = match options {
None => Default::default(),
Some(opts) => InferenceParameters {
n_threads: opts.n_threads.unwrap_or(4),
n_batch: opts.n_batch.unwrap_or(8),
top_k: opts.top_k.unwrap_or(40),
top_p: opts.top_p.unwrap_or(0.95),
temperature: opts.temp.unwrap_or(0.8),
bias_tokens: {
match &opts.bias_tokens {
None => Default::default(),
Some(str) => TokenBias::from_str(str.as_str())
.map_err(|e| Error::InnerError(e.into()))?,
}
},
repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
},
};
let session = &mut self.llm.start_session(Default::default());
let mut output = String::new();
session
.infer::<Infallible>(
self.llm.as_ref(),
&mut rand::thread_rng(),
&InferenceRequest {
prompt: prompt.to_text().as_str(),
parameters: Some(&parameters),
// playback_previous_tokens
// maximum_token_count
..Default::default()
},
// OutputRequest
&mut Default::default(),
|t| {
output.push_str(t);

Ok(())
},
)
.map_err(|e| Error::InnerError(Box::new(e)))?;

Ok(output.into())
}

fn tokens_used(
&self,
options: Option<&Self::PerInvocationOptions>,
prompt: &Prompt,
) -> Result<TokenCount, PromptTokensError> {
let tokenizer = self.get_tokenizer(options)?;
let input = prompt.to_text();

let tokens_used = tokenizer
.tokenize_str(&input)
.map_err(|_e| PromptTokensError::UnableToCompute)?
.len() as i32;
let max_tokens = self.max_tokens_allowed(options);
Ok(TokenCount::new(max_tokens, tokens_used))
}

fn max_tokens_allowed(&self, _: Option<&Self::PerInvocationOptions>) -> i32 {
self.llm.n_context_tokens().try_into().unwrap_or(2048)
}

fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
None
}

fn get_tokenizer(
&self,
_: Option<&Self::PerInvocationOptions>,
) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
Ok(LocalLlmTokenizer::new(self))
}

fn get_text_splitter(
&self,
_: Option<&Self::PerInvocationOptions>,
) -> Result<Self::TextSplitter<'_>, Self::Error> {
Ok(LocalLlmTextSplitter::new(self))
}
}

pub struct LocalLlmTokenizer<'a> {
llm: &'a dyn Model,
}

impl<'a> LocalLlmTokenizer<'a> {
pub fn new(executor: &'a Executor) -> Self {
LocalLlmTokenizer {
llm: executor.llm.as_ref(),
}
}
}

impl Tokenizer<llm::TokenId> for LocalLlmTokenizer<'_> {
fn tokenize_str(&self, doc: &str) -> Result<Vec<llm::TokenId>, TokenizerError> {
match &self.llm.vocabulary().tokenize(doc, false) {
Ok(tokens) => Ok(tokens.into_iter().map(|t| t.1).collect()),
Err(_) => Err(TokenizerError::TokenizationError),
}
}

fn to_string(&self, tokens: Vec<TokenId>) -> Result<String, TokenizerError> {
let mut res = String::new();
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in tokens {
// Buffer the token until it forms valid UTF-8, then append it to the output string.
if let Some(tokens) =
token_utf8_buf.push(self.llm.vocabulary().token(token_id as usize))
{
res.push_str(&tokens)
}
}

Ok(res)
}
}
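As `new_with_options` shows, the model type and path can also come from the `LLM_MODEL_TYPE` and `LLM_MODEL_PATH` environment variables when no `PerExecutor` options are passed. A minimal sketch of that path, mirroring the example above (the model type and location are placeholder values):

use std::env::set_var;
use std::error::Error;

use llm_chain::{prompt::Data, traits::Executor as _};
use llm_chain_local::Executor;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    // Placeholder values; any model type and path accepted by `llm` work here.
    set_var("LLM_MODEL_TYPE", "llama");
    set_var("LLM_MODEL_PATH", "/models/llama");

    // With no PerExecutor options, new_with_options falls back to the
    // LLM_MODEL_TYPE and LLM_MODEL_PATH environment variables.
    let exec = Executor::new_with_options(None, None)?;
    let res = exec
        .execute(None, &Data::Text(String::from("Rust is a cool programming language because")))
        .await?;

    println!("{}", res);
    Ok(())
}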
7 changes: 7 additions & 0 deletions llm-chain-local/src/lib.rs
@@ -0,0 +1,7 @@
mod executor;
pub mod options;
mod output;
mod text_splitter;

pub use executor::Executor;
pub use text_splitter::LocalLlmTextSplitter;
92 changes: 92 additions & 0 deletions llm-chain-local/src/options.rs
@@ -0,0 +1,92 @@
use std::path::PathBuf;

use llm::{InferenceParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
pub n_threads: Option<usize>,
pub n_batch: Option<usize>,
pub n_tok_predict: Option<usize>,
pub top_k: Option<usize>,
pub top_p: Option<f32>,
pub temp: Option<f32>,
/// A comma-separated list of token biases, in the format
/// "TID=BIAS,TID=BIAS", where TID is an integer token ID and BIAS is a
/// floating-point number.
/// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
/// (start of document) and 2 (end of document) to -1.0, which effectively
/// prevents the model from generating responses containing those token IDs.
pub bias_tokens: Option<String>,
pub repeat_penalty: Option<f32>,
pub repeat_penalty_last_n: Option<usize>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
fn into(self) -> ModelParameters {
let inference_parameters = InferenceParameters {
n_threads: self.n_threads.unwrap_or(4),
n_batch: self.n_batch.unwrap_or(8),
top_k: self.top_k.unwrap_or(40),
top_p: self.top_p.unwrap_or(0.95),
repeat_penalty: self.repeat_penalty.unwrap_or(1.3),
repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
temperature: self.temp.unwrap_or(0.8),
bias_tokens: Default::default(),
};

ModelParameters {
prefer_mmap: true,
n_context_tokens: 2048,
inference_parameters,
}
}
}

/// `PerExecutor` represents a collection of configuration parameters for the executor of the LLM.
/// It contains optional fields for the model path and context parameters.
///
/// # Examples
///
/// ```
/// use llm_chain_local::options::PerExecutor;
/// let executor = PerExecutor::new().with_model_path("path/to/model");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
/// Optional path to the LLM.
pub model_path: Option<PathBuf>,
/// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
pub model_type: Option<String>,
}

impl PerExecutor {
/// Creates a new `PerExecutor` instance with default values.
///
/// # Returns
///
/// A `PerExecutor` instance with default values for the model path and context parameters.
pub fn new() -> Self {
Self::default()
}

/// Sets the model path for the current `PerExecutor` instance.
///
/// # Arguments
///
/// * `model_path` - The path to the LLM.
///
/// # Returns
///
/// A new `PerExecutor` instance with the updated model path.
pub fn with_model_path(mut self, model_path: &str) -> Self {
self.model_path = Some(PathBuf::from(model_path));
self
}
}

impl Options for PerExecutor {}
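`PerInvocation` feeds the `InferenceParameters` used in `execute`, so individual calls can override threading, sampling, and token biases. A minimal sketch, using the bias format documented above to suppress token IDs 1 and 2 (the model path, model type, and sampling values are illustrative):

use std::error::Error;
use std::path::PathBuf;

use llm_chain::{prompt::Data, traits::Executor as _};
use llm_chain_local::options::{PerExecutor, PerInvocation};
use llm_chain_local::Executor;

async fn generate_with_options() -> Result<(), Box<dyn Error>> {
    let exec_opts = PerExecutor {
        model_path: Some(PathBuf::from("/models/llama")),
        model_type: Some(String::from("llama")),
    };

    let inv_opts = PerInvocation {
        n_threads: Some(8),
        temp: Some(0.7),
        // "TID=BIAS" pairs, as described in the field's doc comment:
        // bias tokens 1 (start of document) and 2 (end of document) to -1.0.
        bias_tokens: Some(String::from("1=-1.0,2=-1.0")),
        ..Default::default()
    };

    let exec = Executor::new_with_options(Some(exec_opts), Some(inv_opts.clone()))?;
    let res = exec
        .execute(Some(&inv_opts), &Data::Text(String::from("Rust is a cool programming language because")))
        .await?;

    println!("{}", res);
    Ok(())
}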