Commit

Update with latest llm release

danforbes committed May 9, 2023
1 parent bcdc265 commit ccf9051
Showing 5 changed files with 75 additions and 42 deletions.
36 changes: 18 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion llm-chain-local/Cargo.toml
@@ -13,7 +13,7 @@ repository = "https://github.com/sobelio/llm-chain/"

[dependencies]
async-trait = "0.1.68"
llm = "0.1.0-rc3"
llm = "0.1.1"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
4 changes: 2 additions & 2 deletions llm-chain-local/examples/simple.rs
@@ -1,4 +1,4 @@
use std::{env::args, error::Error};
use std::{env::args, error::Error, path::PathBuf};

use llm_chain::{traits::Executor, prompt::Data};
use llm_chain_local::{Executor as LocalExecutor, options::PerExecutor};
@@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let prompt = args.2;

let exec_opts = PerExecutor {
model_path: Some(String::from(model_path)),
model_path: Some(PathBuf::from(model_path)),
model_type: Some(String::from(model_type)),
};

44 changes: 36 additions & 8 deletions llm-chain-local/src/executor.rs
@@ -1,13 +1,17 @@
use std::convert::Infallible;
use std::env::var;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
use llm::{load_progress_callback_stdout, Model, ModelArchitecture, TokenId, TokenUtf8Buffer};
use llm::{
load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
TokenBias, TokenId, TokenUtf8Buffer,
};
use llm_chain::prompt::Prompt;
use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
@@ -58,7 +62,7 @@ impl llm_chain::traits::Executor for Executor {
let model_path = options
.as_ref()
.and_then(|x| x.model_path.clone())
.or_else(|| var("LLM_MODEL_PATH").ok())
.or_else(|| var("LLM_MODEL_PATH").ok().map(|s| PathBuf::from(s)))
.ok_or(ExecutorCreationError::FieldRequiredError(
"model_path, ensure to provide the parameter or set `LLM_MODEL_PATH` environment variable ".to_string(),
))?;
@@ -80,19 +84,43 @@

async fn execute(
&self,
// TODO: call infer_with_params if this is present
_: Option<&Self::PerInvocationOptions>,
options: Option<&Self::PerInvocationOptions>,
prompt: &Prompt,
) -> Result<Self::Output, Self::Error> {
let parameters = match options {
None => Default::default(),
Some(opts) => InferenceParameters {
n_threads: opts.n_threads.unwrap_or(4),
n_batch: opts.n_batch.unwrap_or(8),
top_k: opts.top_k.unwrap_or(40),
top_p: opts.top_p.unwrap_or(0.95),
temperature: opts.temp.unwrap_or(0.8),
bias_tokens: {
match &opts.bias_tokens {
None => Default::default(),
Some(str) => TokenBias::from_str(str.as_str())
.map_err(|e| Error::InnerError(e.into()))?,
}
},
repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
},
};
let session = &mut self.llm.start_session(Default::default());
let mut output = String::new();
session
.infer::<Infallible>(
self.llm.as_ref(),
prompt.to_text().as_str(),
// EvaluateOutputRequest
&mut Default::default(),
&mut rand::thread_rng(),
&InferenceRequest {
prompt: prompt.to_text().as_str(),
parameters: Some(&parameters),
// playback_previous_tokens
// maximum_token_count
..Default::default()
},
// OutputRequest
&mut Default::default(),
|t| {
output.push_str(t);

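With this change, `execute` builds `InferenceParameters` from any `PerInvocation` options the caller supplies instead of ignoring them. Below is a minimal caller-side sketch of the new path; the helper function and all field values are illustrative, not part of this commit:

use llm_chain_local::options::PerInvocation;

// Hypothetical options a caller might pass to `execute`; all values are examples.
fn sample_invocation_options() -> PerInvocation {
    PerInvocation {
        n_threads: Some(8),
        top_k: Some(50),
        temp: Some(0.7),
        // "TID=BIAS" pairs; execute() parses this with TokenBias::from_str.
        // Biasing tokens 1 and 2 to -1.0 effectively suppresses them.
        bias_tokens: Some("1=-1.0,2=-1.0".to_string()),
        ..Default::default()
    }
}

// Passing `Some(&opts)` overrides the defaults shown in the diff above
// (4 threads, batch 8, top_k 40, top_p 0.95, temperature 0.8, repeat_penalty 1.3);
// passing `None` keeps Default::default(), as before.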
31 changes: 18 additions & 13 deletions llm-chain-local/src/options.rs
@@ -1,43 +1,48 @@
use llm::{InferenceParameters, InferenceWithPromptParameters, ModelParameters};
use std::path::PathBuf;

use llm::{InferenceParameters, ModelParameters};
use llm_chain::traits::Options;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
pub struct PerInvocation {
pub n_threads: Option<usize>,
pub n_batch: Option<usize>,
pub n_tok_predict: Option<usize>,
pub top_k: Option<usize>,
pub top_p: Option<f32>,
pub temp: Option<f32>,
/// A comma separated list of token biases. The list should be in the format
/// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
/// floating point number.
/// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
/// (start of document) and 2 (end of document) to -1.0 which effectively
/// disables the model from generating responses containing those token IDs.
pub bias_tokens: Option<String>,
pub repeat_penalty: Option<f32>,
pub stop_sequence: Option<String>,
pub repeat_penalty_last_n: Option<usize>,
}

impl Options for PerInvocation {}

impl Into<ModelParameters> for PerInvocation {
fn into(self) -> ModelParameters {
let inference_params = InferenceParameters {
let inference_parameters = InferenceParameters {
n_threads: self.n_threads.unwrap_or(4),
n_batch: 8,
n_batch: self.n_batch.unwrap_or(8),
top_k: self.top_k.unwrap_or(40),
top_p: self.top_p.unwrap_or(0.95),
repeat_penalty: self.temp.unwrap_or(1.3),
repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
temperature: self.temp.unwrap_or(0.8),
bias_tokens: Default::default(),
};

let prompt_params = InferenceWithPromptParameters {
play_back_previous_tokens: false,
maximum_token_count: None,
};

ModelParameters {
prefer_mmap: true,
n_context_tokens: 2048,
inference_params,
inference_prompt_params: prompt_params,
inference_parameters,
}
}
}
@@ -54,7 +59,7 @@ impl Into<ModelParameters> for PerInvocation {
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
/// Optional path to the LLM.
pub model_path: Option<String>,
pub model_path: Option<PathBuf>,
/// Optional type (e.g. LLaMA, GPT-Neo-X) of the LLM.
pub model_type: Option<String>,
}
@@ -79,7 +84,7 @@ impl PerExecutor {
///
/// A new `PerExecutor` instance with the updated model path.
pub fn with_model_path(mut self, model_path: &str) -> Self {
self.model_path = Some(model_path.to_string());
self.model_path = Some(PathBuf::from(model_path));
self
}
}
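Since `model_path` is now stored as a `PathBuf`, the `with_model_path` builder keeps its `&str` argument and converts internally, so callers using the builder can keep passing string slices. A small usage sketch; the helper function and model path are placeholders, not from this commit:

use llm_chain_local::options::PerExecutor;

// Hypothetical helper; the model path below is an example only.
fn sample_executor_options() -> PerExecutor {
    // `with_model_path` still takes &str and stores it as a PathBuf.
    PerExecutor::default().with_model_path("./models/ggml-model.bin")
}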
