Skip to content

Commit

Permalink
Better handling of feed prompt callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
danforbes committed May 11, 2023
1 parent 5502f08 commit 82f1c68
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 19 deletions.
14 changes: 13 additions & 1 deletion crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl InferenceSession {
parameters,
request.prompt,
output_request,
TokenUtf8Buffer::adapt_callback(&mut callback),
feed_prompt_callback(&mut callback),
)?;
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;
Expand Down Expand Up @@ -634,6 +634,18 @@ pub enum InferenceFeedback {
Halt,
}

/// Adapt an [InferenceResponse] callback so that it can be used in a call to
/// [InferenceSession::feed_prompt].
///
/// The returned closure accepts raw token bytes (`&[u8]`) and buffers them
/// until they form a valid UTF-8 string, at which point the wrapped callback
/// is invoked with an [InferenceResponse::PromptToken]. Incomplete byte
/// sequences yield [InferenceFeedback::Continue] without calling back.
pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
    mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
    // Captured by the closure: accumulates token bytes across calls until
    // they decode to a complete UTF-8 string.
    let mut buffer = TokenUtf8Buffer::new();
    move |token| {
        if let Some(text) = buffer.push(token) {
            // A complete string is available — forward it to the user callback.
            callback(InferenceResponse::PromptToken(text))
        } else {
            // Partial UTF-8 sequence; keep feeding tokens.
            Ok(InferenceFeedback::Continue)
        }
    }
}

fn scratch_buffers() -> [ggml::Buffer; 2] {
[
ggml::Buffer::new(SCRATCH_SIZE),
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use ggml;
pub use ggml::Type as ElementType;

pub use inference_session::{
InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
feed_prompt_callback, InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
InferenceSessionConfig, InferenceSnapshot, InferenceStats, ModelKVMemoryType, SnapshotError,
};
pub use loader::{
Expand Down
14 changes: 0 additions & 14 deletions crates/llm-base/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ macro_rules! mulf {
use memmap2::{Mmap, MmapAsRawDesc, MmapOptions};
use thiserror::Error;

use crate::inference_session::{InferenceFeedback, InferenceResponse};

/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text.
///
/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8
Expand Down Expand Up @@ -61,18 +59,6 @@ impl TokenUtf8Buffer {
}
}
}

/// Adapt an [InferenceResponse] callback so that it can be used in a `&[u8]`
/// context.
pub fn adapt_callback<'a, E: std::error::Error + 'static>(
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
) -> impl FnMut(&[u8]) -> Result<crate::inference_session::InferenceFeedback, E> + 'a {
let mut buffer = Self::new();
move |token| match buffer.push(token) {
Some(tokens) => callback(InferenceResponse::PromptToken(tokens)),
None => Ok(InferenceFeedback::Continue),
}
}
}

#[derive(Error, Debug)]
Expand Down
6 changes: 3 additions & 3 deletions crates/llm/examples/vicuna-chat.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use llm_base::{
InferenceFeedback, InferenceRequest, InferenceResponse, InferenceStats, LoadProgress,
TokenUtf8Buffer,
feed_prompt_callback, InferenceFeedback, InferenceRequest, InferenceResponse, InferenceStats,
LoadProgress,
};
use rustyline::error::ReadlineError;
use spinoff::{spinners::Dots2, Spinner};
Expand Down Expand Up @@ -50,7 +50,7 @@ fn main() {
&Default::default(),
format!("{persona}\n{history}").as_str(),
&mut Default::default(),
TokenUtf8Buffer::adapt_callback(prompt_callback),
feed_prompt_callback(prompt_callback),
)
.expect("Failed to ingest initial prompt.");

Expand Down

0 comments on commit 82f1c68

Please sign in to comment.