
Commit: Update with latest llm release
danforbes committed May 9, 2023
1 parent bcdc265 commit 9c05a5d
Showing 4 changed files with 67 additions and 36 deletions.
36 changes: 18 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default; the Cargo.lock diff is not shown.

2 changes: 1 addition & 1 deletion llm-chain-local/Cargo.toml
@@ -13,7 +13,7 @@ repository = "https://github.com/sobelio/llm-chain/"

 [dependencies]
 async-trait = "0.1.68"
-llm = "0.1.0-rc3"
+llm = "0.1.1"
 llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
 rand = "0.8.5"
 serde = { version = "1.0.160", features = ["derive"] }
40 changes: 34 additions & 6 deletions llm-chain-local/src/executor.rs
@@ -1,13 +1,17 @@
 use std::convert::Infallible;
 use std::env::var;
 use std::path::Path;
+use std::str::FromStr;
 
 use crate::options::{PerExecutor, PerInvocation};
 use crate::output::Output;
 use crate::LocalLlmTextSplitter;
 
 use async_trait::async_trait;
-use llm::{load_progress_callback_stdout, Model, ModelArchitecture, TokenId, TokenUtf8Buffer};
+use llm::{
+    load_progress_callback_stdout, InferenceParameters, InferenceRequest, Model, ModelArchitecture,
+    TokenBias, TokenId, TokenUtf8Buffer,
+};
 use llm_chain::prompt::Prompt;
 use llm_chain::tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError};
 use llm_chain::traits::{ExecutorCreationError, ExecutorError};
@@ -80,19 +84,43 @@ impl llm_chain::traits::Executor for Executor {

     async fn execute(
         &self,
-        // TODO: call infer_with_params if this is present
-        _: Option<&Self::PerInvocationOptions>,
+        options: Option<&Self::PerInvocationOptions>,
         prompt: &Prompt,
     ) -> Result<Self::Output, Self::Error> {
+        let parameters = match options {
+            None => Default::default(),
+            Some(opts) => InferenceParameters {
+                n_threads: opts.n_threads.unwrap_or(4),
+                n_batch: opts.n_batch.unwrap_or(8),
+                top_k: opts.top_k.unwrap_or(40),
+                top_p: opts.top_p.unwrap_or(0.95),
+                temperature: opts.temp.unwrap_or(0.8),
+                bias_tokens: {
+                    match &opts.bias_tokens {
+                        None => Default::default(),
+                        Some(str) => TokenBias::from_str(str.as_str())
+                            .map_err(|e| Error::InnerError(e.into()))?,
+                    }
+                },
+                repeat_penalty: opts.repeat_penalty.unwrap_or(1.3),
+                repetition_penalty_last_n: opts.repeat_penalty_last_n.unwrap_or(512),
+            },
+        };
         let session = &mut self.llm.start_session(Default::default());
         let mut output = String::new();
         session
             .infer::<Infallible>(
                 self.llm.as_ref(),
-                prompt.to_text().as_str(),
-                // EvaluateOutputRequest
-                &mut Default::default(),
                 &mut rand::thread_rng(),
+                &InferenceRequest {
+                    prompt: prompt.to_text().as_str(),
+                    parameters: Some(&parameters),
+                    // playback_previous_tokens
+                    // maximum_token_count
+                    ..Default::default()
+                },
+                // OutputRequest
+                &mut Default::default(),
                 |t| {
                     output.push_str(t);

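The most delicate part of the new execute() above is the bias_tokens handling: an optional override string is parsed into llm's TokenBias through its FromStr implementation, and an absent override falls back to the default bias. The standalone sketch below mirrors that match for illustration only; it is not part of the commit, the helper name is invented, and the error type is left to the FromStr associated type so nothing extra is assumed about the llm API.

use std::str::FromStr;

use llm::TokenBias;

// Hypothetical helper (not from the commit) mirroring the bias_tokens match
// in the new execute(): no override means the default TokenBias, otherwise
// the "TID=BIAS,TID=BIAS" string is parsed via llm's FromStr impl.
fn bias_from_option(spec: Option<&str>) -> Result<TokenBias, <TokenBias as FromStr>::Err> {
    match spec {
        None => Ok(TokenBias::default()),
        // e.g. "1=-1.0,2=-1.0" biases token IDs 1 and 2 to -1.0, which the
        // field docs in options.rs below describe as effectively preventing
        // those tokens from being generated.
        Some(s) => TokenBias::from_str(s),
    }
}

In execute() itself the parse error is wrapped in Error::InnerError and propagated with ?, as the hunk above shows.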
25 changes: 14 additions & 11 deletions llm-chain-local/src/options.rs
@@ -1,43 +1,46 @@
-use llm::{InferenceParameters, InferenceWithPromptParameters, ModelParameters};
+use llm::{InferenceParameters, ModelParameters};
 use llm_chain::traits::Options;
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Default, Serialize, Deserialize)]
 /// An overridable collection of configuration parameters for an LLM. It is combined with a prompt to create an invocation.
 pub struct PerInvocation {
     pub n_threads: Option<usize>,
+    pub n_batch: Option<usize>,
     pub n_tok_predict: Option<usize>,
     pub top_k: Option<usize>,
     pub top_p: Option<f32>,
     pub temp: Option<f32>,
+    /// A comma separated list of token biases. The list should be in the format
+    /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a
+    /// floating point number.
+    /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1
+    /// (start of document) and 2 (end of document) to -1.0 which effectively
+    /// disables the model from generating responses containing those token IDs.
+    pub bias_tokens: Option<String>,
     pub repeat_penalty: Option<f32>,
     pub stop_sequence: Option<String>,
+    pub repeat_penalty_last_n: Option<usize>,
 }
 
 impl Options for PerInvocation {}
 
 impl Into<ModelParameters> for PerInvocation {
     fn into(self) -> ModelParameters {
-        let inference_params = InferenceParameters {
+        let inference_parameters = InferenceParameters {
             n_threads: self.n_threads.unwrap_or(4),
-            n_batch: 8,
+            n_batch: self.n_batch.unwrap_or(8),
             top_k: self.top_k.unwrap_or(40),
             top_p: self.top_p.unwrap_or(0.95),
             repeat_penalty: self.temp.unwrap_or(1.3),
+            repetition_penalty_last_n: self.repeat_penalty_last_n.unwrap_or(512),
             temperature: self.temp.unwrap_or(0.8),
             bias_tokens: Default::default(),
         };
-
-        let prompt_params = InferenceWithPromptParameters {
-            play_back_previous_tokens: false,
-            maximum_token_count: None,
-        };
-
         ModelParameters {
             prefer_mmap: true,
             n_context_tokens: 2048,
-            inference_params,
-            inference_prompt_params: prompt_params,
+            inference_parameters,
         }
     }
 }
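To see the new option plumbing end to end, here is a minimal sketch of a unit test that could sit at the bottom of options.rs and exercise the updated Into<ModelParameters> conversion. It is not part of the commit: the test name and override values are illustrative, and it assumes the InferenceParameters fields remain publicly readable, which the struct-literal construction above already implies.

#[cfg(test)]
mod tests {
    use super::PerInvocation;
    use llm::ModelParameters;

    #[test]
    fn overrides_flow_into_inference_parameters() {
        // Override a couple of knobs; the rest fall back to the unwrap_or
        // defaults chosen in the Into<ModelParameters> impl above.
        let opts = PerInvocation {
            n_threads: Some(8),
            n_batch: Some(16),
            ..Default::default()
        };
        let params: ModelParameters = opts.into();
        assert_eq!(params.inference_parameters.n_threads, 8);
        assert_eq!(params.inference_parameters.n_batch, 16);
        // An unset field takes the default picked by the conversion.
        assert_eq!(params.inference_parameters.top_k, 40);
    }
}

Running cargo test -p llm-chain-local (assuming the package name matches the directory) would then confirm that per-invocation overrides reach the inference parameters that llm 0.1.1 expects.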

0 comments on commit 9c05a5d
