Commit 4eb0bc9

fix: start_completing should not be invoked on a per-iteration basis
There's still some UB that can be triggered due to llama.cpp's threading model, which needs patching up.
1 parent: 94d7385
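In practice this means the completion handle is created once and then polled, rather than re-acquired inside the loop. A rough sketch of the corrected pattern, mirroring the updated crate-level docs (session setup omitted; `ctx` is assumed to be a `LlamaSession`):

// Before this commit the docs suggested re-creating the handle on every
// iteration, which dropped it (and its worker thread) immediately:
//     while let Some(tok) = ctx.get_completions().next_token() { ... }
//
// After: bind the handle once, then drain it. Sketch only; setup omitted.
let mut completions = ctx.start_completing();

while let Some(next_token) = completions.next_token() {
    println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
}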

File tree

4 files changed: +86, -8 lines changed

Cargo.lock

Lines changed: 74 additions & 0 deletions
Some generated files are not rendered by default.

crates/llama_cpp/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -22,3 +22,4 @@ num_cpus = "1.16.0"
 thiserror = "1.0.49"
 tokio = { version = "1.33.0", features = ["sync"] }
 tracing = "0.1.39"
+tracing-subscriber = { version = "0.3.17", features = ["fmt"] }
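The new `tracing-subscriber` dependency suggests something in the crate (likely an example or test) now needs to render `tracing` events. A minimal sketch of how that might be initialized; the `init_logging` function below is illustrative and not taken from this commit:

// Hypothetical setup, not part of this diff: install a formatting subscriber
// so `tracing` events (e.g. the `info!` call in `start_completing`) are printed.
fn init_logging() {
    tracing_subscriber::fmt().init();
}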

crates/llama_cpp/src/lib.rs

Lines changed: 10 additions & 7 deletions
@@ -26,7 +26,10 @@
 //!
 //! // `ctx.get_completions` creates a worker thread that generates tokens. When the completion
 //! // handle is dropped, tokens stop generating!
-//! while let Some(next_token) = ctx.get_completions().next_token() {
+//!
+//! let mut completions = ctx.start_completing();
+//!
+//! while let Some(next_token) = completions.next_token() {
 //!     println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
 //!
 //!     decoded_tokens += 1;
@@ -315,7 +318,7 @@ impl LlamaModel {
             //
             // `out_buf` is a `Vec<Token>`, and `Token` is `#[repr(transparent)]` over an `i32`.
             llama_tokenize(
-                **self.model.blocking_read(),
+                **self.model.try_read().unwrap(),
                 content.as_ptr() as *const i8,
                 content.len() as i32,
                 out_buf.as_mut_ptr() as *mut i32,
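The switch from `blocking_read()` to `try_read().unwrap()` trades a blocking lock acquisition for a fail-fast one, which matters when the caller may already be on an async runtime thread. A small sketch of the difference, assuming `self.model` wraps a `tokio::sync::RwLock` (the crate depends on tokio with the `sync` feature); the `Model` type and `with_model` helper are made up for illustration:

use std::sync::Arc;
use tokio::sync::RwLock;

// Illustrative stand-in for whatever the session's shared model handle wraps.
struct Model;

fn with_model(model: &Arc<RwLock<Model>>) {
    // `blocking_read()` parks the thread until the lock is free and panics if
    // called from within an async execution context.
    // `try_read()` never waits: it returns `Err` when the lock is held, so the
    // `.unwrap()` assumes no writer holds the lock at this point.
    let _guard = model.try_read().unwrap();
}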
@@ -352,7 +355,7 @@ impl LlamaModel {
             token.0
         );

-        unsafe { CStr::from_ptr(llama_token_get_text(**self.model.blocking_read(), token.0)) }.to_bytes()
+        unsafe { CStr::from_ptr(llama_token_get_text(**self.model.try_read().unwrap(), token.0)) }.to_bytes()
     }

     /// Creates a new evaluation context for this model.
@@ -581,7 +584,7 @@ impl LlamaSession {

     /// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
     /// This is where you want to be if you just want some completions.
-    pub fn get_completions(&mut self) -> CompletionHandle {
+    pub fn start_completing(&mut self) -> CompletionHandle {
         let (tx, rx) = flume::unbounded();

         info!(
@@ -599,7 +602,7 @@ impl LlamaSession {
                 Box::leak(Box::new(detail::BeamSearchState { tx })) as *mut _ as *mut c_void,
                 1,
                 past_tokens as i32,
-                2048,
+                32_768,
             );
         });

@@ -657,7 +660,7 @@ pub struct CompletionHandle<'a> {
 impl<'a> CompletionHandle<'a> {
     /// Blocks the current thread, resolving to the next completed token, or `None` if EOS is
     /// reached.
-    pub fn next_token(&self) -> Option<CompletionToken<'_>> {
+    pub fn next_token(&mut self) -> Option<CompletionToken<'_>> {
         self.rx.recv().ok().map(|token| CompletionToken {
             ctx: self.ctx,
             token,
@@ -666,7 +669,7 @@ impl<'a> CompletionHandle<'a> {

     /// Asynchronously yields the current thread, resolving to the next completed token, or `None`
     /// if EOS is reached.
-    pub async fn next_token_async(&self) -> Option<CompletionToken<'_>> {
+    pub async fn next_token_async(&mut self) -> Option<CompletionToken<'_>> {
         self.rx
             .recv_async()
             .await
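Since both `next_token` and `next_token_async` now take `&mut self`, the handle has to be bound mutably whether it is driven from blocking or async code. A rough sketch of the async shape, assuming a `LlamaSession` named `ctx` as in the crate-level example (model loading and error handling omitted):

// Sketch only: assumes `ctx` already holds a prepared session, as in the
// crate-level docs; `CompletionToken::as_bytes` is used as shown there.
async fn stream_completions(ctx: &mut LlamaSession) {
    let mut completions = ctx.start_completing();

    // `next_token_async` yields to the executor instead of blocking a thread.
    while let Some(token) = completions.next_token_async().await {
        print!("{}", String::from_utf8_lossy(token.as_bytes()));
    }
}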
