
Local LLM Backend #116

Merged
merged 10 commits on May 9, 2023
345 changes: 341 additions & 4 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -4,6 +4,7 @@ members = [
"llm-chain-openai",
"llm-chain-llama",
"llm-chain-llama/sys",
"llm-chain-local",
"llm-chain-qdrant",
]

23 changes: 23 additions & 0 deletions llm-chain-local/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "llm-chain-local"
version = "0.9.0"
edition = "2021"
description = "Use `llm-chain` with a local [`llm`](https://github.com/rustformers/llm) backend."
license = "MIT"
keywords = ["llm", "langchain", "ggml", "chain"]
categories = ["science"]
authors = ["Dan Forbes <dan@danforbes.dev>"]
repository = "https://github.com/sobelio/llm-chain/"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.68"
llm = "0.1.1"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }
41 changes: 41 additions & 0 deletions llm-chain-local/examples/simple.rs
@@ -0,0 +1,41 @@
use std::{env::args, error::Error, path::PathBuf};

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::{options::PerExecutor, Executor as LocalExecutor};

extern crate llm_chain_local;

/// This example demonstrates how to use the llm-chain-local crate to generate text using a model.
///
/// Usage: cargo run --release --package llm-chain-local --example simple model_type path/to/model
///
/// For example, if the model is a LLaMA-type model located at "/models/llama"
/// cargo run --release --package llm-chain-local --example simple llama /models/llama
///
/// An optional third argument can be used to customize the prompt passed to the model.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
let raw_args: Vec<String> = args().collect();
let args = match &raw_args.len() {
3 => (raw_args[1].as_str(), raw_args[2].as_str(), "Rust is a cool programming language because"),
4 => (raw_args[1].as_str(), raw_args[2].as_str(), raw_args[3].as_str()),
_ => panic!("Usage: cargo run --release --package llm-chain-local --example simple <model type> <path to model> <optional prompt>")
};

let model_type = args.0;
let model_path = args.1;
let prompt = args.2;

let exec_opts = PerExecutor {
model_path: Some(PathBuf::from(model_path)),
model_type: Some(String::from(model_type)),
};

let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
let res = exec
.execute(None, &Data::Text(String::from(prompt)))
.await?;

println!("{}", res);
Ok(())
}
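
For reference, the executor introduced below in executor.rs also falls back to the `LLM_MODEL_TYPE` and `LLM_MODEL_PATH` environment variables when the corresponding options are not supplied. The following is a minimal sketch (not part of this PR) of that invocation style, assuming both variables are exported and point at a valid GGML model:

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::Executor as LocalExecutor;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumes LLM_MODEL_TYPE (e.g. "llama") and LLM_MODEL_PATH are set;
    // new_with_options reads them when no PerExecutor options are given.
    let exec = LocalExecutor::new_with_options(None, None)?;
    let res = exec
        .execute(None, &Data::Text(String::from("Rust is a cool programming language because")))
        .await?;
    println!("{}", res);
    Ok(())
}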
208 changes: 208 additions & 0 deletions llm-chain-local/src/executor.rs
@@ -0,0 +1,208 @@
use std::convert::Infallible;
use std::env::var;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::{
load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
TokenBias, TokenId, TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use thiserror::Error;

/// Executor is responsible for running the LLM and managing its context.
pub struct Executor {
llm: Box<dyn Model>,
}

impl Executor {
pub(crate) fn get_llm(&self) -> &dyn Model {
self.llm.as_ref()
}
}

#[derive(Debug, Error)]
pub enum Error {
#[error("unable to tokenize prompt")]
PromptTokensError(PromptTokensError),
#[error("unable to create executor: {0}")]
InnerError(#[from] Box<dyn std::error::Error>),
}

impl ExecutorError for Error {}

#[async_trait]
impl llm_chain::traits::Executor for Executor {
type PerInvocationOptions = PerInvocation;
type PerExecutorOptions = PerExecutor;
type Output = Output;
type Error = Error;
type Token = i32;
type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
type TextSplitter<'a> = LocalLlmTextSplitter<'a>;

fn new_with_options(
options: Option<Self::PerExecutorOptions>,
invocation_options: Option<Self::PerInvocationOptions>,
) -> Result<Self, ExecutorCreationError> {
let model_type = options
.as_ref()
.and_then(|x| x.model_type.clone())
.or_else(|| var("LLM_MODEL_TYPE").ok())
.ok_or(ExecutorCreationError::FieldRequiredError(
"model_type, ensure to provide the parameter or set `LLM_MODEL_TYPE` environment variable ".to_string(),
))?;
let model_path = options
.as_ref()
.and_then(|x| x.model_path.clone())
.or_else(|| var("LLM_MODEL_PATH").ok().map(|s| PathBuf::from(s)))
.ok_or(ExecutorCreationError::FieldRequiredError(
"model_path, ensure to provide the parameter or set `LLM_MODEL_PATH` environment variable ".to_string(),
))?;

let model_arch = model_type
.parse::<ModelArchitecture>()
.map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;
let params = invocation_options.unwrap_or_default().into();
let llm: Box<dyn Model> = llm::load_dynamic(
model_arch,
Path::new(&model_path),
params,
load_progress_callback_stdout,
)
.map_err(|e| ExecutorCreationError::InnerError(Box::new(e)))?;

Ok(Executor { llm })
}

async fn execute(
&self,
options: Option<&Self::PerInvocationOptions>,
prompt: &Prompt,
) -> Result<Self::Output, Self::Error> {
let parameters = match options {
None => Default::default(),
Some(opts) => InferenceParameters {
n_threads: opts.n_threads.unwrap_or(4),
n_batch: opts.n_batch.unwrap_or(8),
top_k: opts.top_k.unwrap_or(40),
top_p: opts.top_p.unwrap_or(0.95),
temperature: opts.temp.unwrap_or(0.8),
bias_tokens: {
match &opts.bias_tokens {
None => Default::default(),
Some(str) => TokenBias::from_str(str.as_str())
.map_err(|e| Error::InnerError(e.into()))?,
}
},
repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
},
};
let session = &mut self.llm.start_session(Default::default());
let mut output = String::new();
session
.infer::<Infallible>(
self.llm.as_ref(),
&mut rand::thread_rng(),
&InferenceRequest {
prompt: prompt.to_text().as_str(),
parameters: Some(&parameters),
// playback_previous_tokens
// maximum_token_count
..Default::default()
},
// OutputRequest
&mut Default::default(),
|t| {
output.push_str(t);

Ok(())
},
)
.map_err(|e| Error::InnerError(Box::new(e)))?;

Ok(output.into())
}

fn tokens_used(
&self,
options: Option<&Self::PerInvocationOptions>,
prompt: &Prompt,
) -> Result<TokenCount, PromptTokensError> {
let tokenizer = self.get_tokenizer(options)?;
let input = prompt.to_text();

let tokens_used = tokenizer
.tokenize_str(&input)
.map_err(|_e| PromptTokensError::UnableToCompute)?
.len() as i32;
let max_tokens = self.max_tokens_allowed(options);
Ok(TokenCount::new(max_tokens, tokens_used))
}

fn max_tokens_allowed(&self, _: Option<&Self::PerInvocationOptions>) -> i32 {
self.llm.n_context_tokens().try_into().unwrap_or(2048)
}

fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
None
}

fn get_tokenizer(
&self,
_: Option<&Self::PerInvocationOptions>,
) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
Ok(LocalLlmTokenizer::new(self))
}

fn get_text_splitter(
&self,
_: Option<&Self::PerInvocationOptions>,
) -> Result<Self::TextSplitter<'_>, Self::Error> {
Ok(LocalLlmTextSplitter::new(self))
}
}

pub struct LocalLlmTokenizer<'a> {
llm: &'a dyn Model,
}

impl<'a> LocalLlmTokenizer<'a> {
pub fn new(executor: &'a Executor) -> Self {
LocalLlmTokenizer {
llm: executor.llm.as_ref(),
}
}
}

impl Tokenizer<llm::TokenId> for LocalLlmTokenizer<'_> {
fn tokenize_str(&self, doc: &str) -> Result<Vec<llm::TokenId>, TokenizerError> {
match &self.llm.vocabulary().tokenize(doc, false) {
Ok(tokens) => Ok(tokens.into_iter().map(|t| t.1).collect()),
Err(_) => Err(TokenizerError::TokenizationError),
}
}

fn to_string(&self, tokens: Vec<TokenId>) -> Result<String, TokenizerError> {
let mut res = String::new();
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(self.llm.vocabulary().token(token_id as usize))
{
res.push_str(&tokens)
}
}

Ok(res)
}
}
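
Since get_tokenizer and the Tokenizer implementation above are public through the trait, prompt length can be checked without running inference. A hedged sketch (not part of the PR) that counts tokens with tokenize_str; the helper name is illustrative:

use llm_chain::tokens::{Tokenizer, TokenizerError};
use llm_chain::traits::Executor;

// Hypothetical helper: returns how many tokens the loaded model assigns to `text`.
fn count_tokens(exec: &llm_chain_local::Executor, text: &str) -> Result<usize, TokenizerError> {
    let tokenizer = exec.get_tokenizer(None)?; // LocalLlmTokenizer borrowing the loaded model
    let ids = tokenizer.tokenize_str(text)?;   // Vec<llm::TokenId>
    Ok(ids.len())
}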
7 changes: 7 additions & 0 deletions llm-chain-local/src/lib.rs
@@ -0,0 +1,7 @@
mod executor;
pub mod options;
mod output;
mod text_splitter;

pub use executor::Executor;
pub use text_splitter::LocalLlmTextSplitter;
92 changes: 92 additions & 0 deletions llm-chain-local/src/options.rs
@@ -0,0 +1,92 @@
use std::path::PathBuf;

use llm::{InferenceParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
pub n_threads: Option<usize>,
pub n_batch: Option<usize>,
pub n_tok_predict: Option<usize>,
pub top_k: Option<usize>,
pub top_p: Option<f32>,
pub temp: Option<f32>,
/// A comma separated list of token biases. The list should be in the format
/// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
/// floating point number.
/// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
/// (start of document) and 2 (end of document) to -1.0 which effectively
/// disables the model from generating responses containing those token IDs.
pub bias_tokens: Option<String>,
pub repeat_penalty: Option<f32>,
pub repeat_penalty_last_n: Option<usize>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
fn into(self) -> ModelParameters {
let inference_parameters = InferenceParameters {
n_threads: self.n_threads.unwrap_or(4),
n_batch: self.n_batch.unwrap_or(8),
top_k: self.top_k.unwrap_or(40),
top_p: self.top_p.unwrap_or(0.95),
repeat_penalty: self.repeat_penalty.unwrap_or(1.3),
repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
temperature: self.temp.unwrap_or(0.8),
bias_tokens: Default::default(),
};

ModelParameters {
prefer_mmap: true,
n_context_tokens: 2048,
inference_parameters,
}
}
}

/// `PerExecutor` represents a collection of configuration parameters for the executor of the LLM.
/// It contains optional fields for the model path and context parameters.
///
/// # Examples
///
/// ```
/// use llm_chain_local::options::PerExecutor;
/// let executor = PerExecutor::new().with_model_path("path/to/model");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
/// Optional path to the LLM.
pub model_path: Option<PathBuf>,
/// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
pub model_type: Option<String>,
@drager (Contributor) commented on May 6, 2023:

Any thoughts on having an enum with a fallback, to make it clear which models are supported?

I mean something like:

enum ModelType {
  LLama,
  GPTNeoX,
  Custom(String)
}

The author (Contributor) replied:

The model architecture enum isn't currently serializable (although I don't see why it can't be), so I'm leaving this as a string for now.

@drager replied:

Ah alright! I see!

The author replied:

Hopefully this will be changed soon: rustformers/llm#200

}

impl PerExecutor {
/// Creates a new `PerExecutor` instance with default values.
///
/// # Returns
///
/// A `PerExecutor` instance with default values for the model path and context parameters.
pub fn new() -> Self {
Self::default()
}

/// Sets the model path for the current `PerExecutor` instance.
///
/// # Arguments
///
/// * `model_path` - The path to the LLM.
///
/// # Returns
///
/// A new `PerExecutor` instance with the updated model path.
pub fn with_model_path(mut self, model_path: &str) -> Self {
self.model_path = Some(PathBuf::from(model_path));
self
}
}

impl Options for PerExecutor {}
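
To show how the PerInvocation options and the bias_tokens format documented above feed into a call, here is a hedged sketch (not part of the PR) of overriding the defaults for a single execute; the specific values are illustrative only:

use llm_chain::{prompt::Data, traits::Executor};
use llm_chain_local::options::PerInvocation;

async fn run(exec: &llm_chain_local::Executor) -> Result<(), Box<dyn std::error::Error>> {
    let opts = PerInvocation {
        n_threads: Some(8),
        temp: Some(0.2),
        // "TID=BIAS" format; token ID 2 is the end-of-document token mentioned above.
        bias_tokens: Some(String::from("2=-1.0")),
        ..Default::default()
    };
    let res = exec
        .execute(Some(&opts), &Data::Text(String::from("Rust is a cool programming language because")))
        .await?;
    println!("{}", res);
    Ok(())
}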