Skip to content

Commit

Permalink
Broken streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
danforbes committed May 7, 2023
1 parent b7be6d3 commit 36b2b23
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 25 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions llm-chain-local/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ repository = "https://github.com/sobelio/llm-chain/"

[dependencies]
async-trait = "0.1.68"
futures = "0.3.28"
llm = "0.1.0-rc3"
llm-chain = { path = "../llm-chain", version = "0.9.0", default-features = false }
rand = "0.8.5"
serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.40"

[dev-dependencies]
tokio = { version = "1.28.0", features = ["macros", "rt"] }

5 changes: 4 additions & 1 deletion llm-chain-local/examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ async fn main() -> Result<(), Box<dyn Error>> {

let exec = LocalExecutor::new_with_options(Some(exec_opts), None)?;
let res = exec.execute(None, &Data::Text(String::from(prompt))).await?;

while let Some(chunk) = res.receiver.lock().unwrap().recv().await {
print!("{chunk}")
}

println!("{}", res);
Ok(())
}
15 changes: 9 additions & 6 deletions llm-chain-local/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::env::var;
use std::path::Path;

use crate::options::{PerExecutor, PerInvocation};
use crate::output::Output;
use crate::output::{Output, StreamingOutput};
use crate::LocalLlmTextSplitter;

use async_trait::async_trait;
Expand Down Expand Up @@ -38,7 +38,7 @@ impl ExecutorError for Error {}
impl llm_chain::traits::Executor for Executor {
type PerInvocationOptions = PerInvocation;
type PerExecutorOptions = PerExecutor;
type Output = Output;
type Output = StreamingOutput;
type Error = Error;
type Token = i32;
type StepTokenizer<'a> = LocalLlmTokenizer<'a>;
Expand Down Expand Up @@ -85,23 +85,26 @@ impl llm_chain::traits::Executor for Executor {
prompt: &Prompt,
) -> Result<Self::Output, Self::Error> {
let session = &mut self.llm.start_session(Default::default());
let mut output = String::new();

let (send, recv) = tokio::sync::mpsc::channel::<String>(8);
let mut stream_out = StreamingOutput::new(recv);
session
.infer::<Infallible>(
.infer::<tokio::sync::mpsc::error::TrySendError<String>>(
self.llm.as_ref(),
prompt.to_text().as_str(),
// EvaluateOutputRequest
&mut Default::default(),
&mut rand::thread_rng(),
|t| {
output.push_str(t);
stream_out.add_chunk(String::from(t));
send.blocking_send(String::from(t))?;

Ok(())
},
)
.map_err(|e| Error::InnerError(Box::new(e)))?;

Ok(output.into())
Ok(stream_out)
}

fn tokens_used(
Expand Down
22 changes: 8 additions & 14 deletions llm-chain-local/src/output.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_trait::async_trait;
use llm_chain::output;
use std::{fmt::{Display, Formatter}, sync::{Arc, Mutex, mpsc}};
use std::{fmt::{Display, Formatter}, sync::{Arc, Mutex}};

/// Represents the output from the LLAMA model.
#[derive(Debug, Clone)]
Expand All @@ -10,34 +10,28 @@ pub struct Output {

#[derive(Clone, Debug)]
pub struct StreamingOutput {
complete: bool,
output: Arc<Mutex<Vec<String>>>,
receiver: Arc<Mutex<futures::channel::mpsc::Receiver<String>>>
pub receiver: Arc<Mutex<tokio::sync::mpsc::Receiver<String>>>
}

impl StreamingOutput {
fn new() -> Self {
let (_, receiver) = futures::channel::mpsc::channel(8);
pub(crate) fn new(recv: tokio::sync::mpsc::Receiver<String>) -> Self {
Self {
complete: false,
output: Arc::new(Mutex::new(vec![])),
receiver: Arc::new(Mutex::new(receiver))
receiver: Arc::new(Mutex::new(recv))
}
}

fn set_complete(&mut self) {
self.complete = true;
}

fn complete(&self) -> bool {
self.complete
pub(crate) fn add_chunk(&mut self, chunk: String) {
let mut chunks = self.output.lock().unwrap();
chunks.push(chunk);
}
}

#[async_trait]
impl output::Output for StreamingOutput {
async fn primary_textual_output_choices(&self) -> Vec<String> {
todo!()
self.output.lock().unwrap().clone()
}
}

Expand Down

0 comments on commit 36b2b23

Please sign in to comment.