
Commit 1019402

feat: more async function variants

1 parent: c190df6

4 files changed: +107 -30 lines changed

Cargo.lock

Lines changed: 16 additions & 0 deletions
Generated file; diff not rendered.

README.md

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@ let mut decoded_tokens = 0;
 // handle is dropped, tokens stop generating!
 let mut completions = ctx.get_completions();
 
-while let Some(next_token) = completions.next_token() {
-    println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
+while let Some(next_token) = completions.detokenize() {
+    println!("{}", String::from_utf8_lossy(&*next_token.as_bytes()));
     decoded_tokens += 1;
     if decoded_tokens > max_tokens {
         break;

crates/llama_cpp/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -16,5 +16,6 @@ flume = "0.11.0"
 llama_cpp_sys = { version = "^0.2.1", path = "../llama_cpp_sys" }
 num_cpus = "1.16.0"
 thiserror = "1.0.50"
+tinyvec = { version = "1.6.0", features = ["alloc"] }
 tokio = { version = "1.33.0", features = ["sync", "rt"] }
 tracing = "0.1.39"
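
The new `tinyvec` dependency (with its `alloc` feature) backs the `TinyVec<[u8; 8]>` buffer that `CompletionToken::detokenize` returns in the lib.rs changes below: token text that fits in eight bytes is stored inline, and anything longer spills to a heap `Vec`. A minimal standalone sketch of that behaviour (not code from this crate):

```rust
use tinyvec::TinyVec;

fn main() {
    // Most detokenized tokens are only a few bytes, so they fit in the
    // inline [u8; 8] storage and need no allocation.
    let mut short: TinyVec<[u8; 8]> = TinyVec::default();
    short.extend_from_slice(b"the");
    assert!(matches!(short, TinyVec::Inline(_)));

    // Longer byte strings transparently move to the heap; this is what
    // the "alloc" feature enables.
    let mut long: TinyVec<[u8; 8]> = TinyVec::default();
    long.extend_from_slice(b"a longer detokenized piece of text");
    assert!(matches!(long, TinyVec::Heap(_)));
}
```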

crates/llama_cpp/src/lib.rs

Lines changed: 88 additions & 28 deletions
@@ -30,7 +30,7 @@
 //! let mut completions = ctx.start_completing();
 //!
 //! while let Some(next_token) = completions.next_token() {
-//!     println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
+//!     println!("{}", String::from_utf8_lossy(&*next_token.detokenize()));
 //!
 //!     decoded_tokens += 1;
 //!
@@ -74,10 +74,13 @@
 //! [llama.cpp]: https://github.com/ggerganov/llama.cpp/
 
 #![warn(missing_docs)]
+
 use std::ffi::{c_void, CStr, CString};
 use std::path::{Path, PathBuf};
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::{ptr, thread};
+use tinyvec::TinyVec;
 use tokio::sync::{Mutex, RwLock};
 
 use ctor::{ctor, dtor};
@@ -184,6 +187,7 @@ pub struct LlamaInternalError;
 struct LlamaModelInner(*mut llama_model);
 
 unsafe impl Send for LlamaModelInner {}
+
 unsafe impl Sync for LlamaModelInner {}
 
 impl Drop for LlamaModelInner {
@@ -297,7 +301,9 @@ impl LlamaModel {
     pub async fn load_from_file_async(file_path: impl AsRef<Path>) -> Result<Self, LlamaLoadError> {
         let path = file_path.as_ref().to_owned();
 
-        tokio::task::spawn_blocking(move || Self::load_from_file(path)).await.unwrap()
+        tokio::task::spawn_blocking(move || Self::load_from_file(path))
+            .await
+            .unwrap()
     }
 
     /// Converts `content` into a vector of tokens that are valid input for this model.
@@ -364,7 +370,13 @@ impl LlamaModel {
             token.0
         );
 
-        unsafe { CStr::from_ptr(llama_token_get_text(**self.model.try_read().unwrap(), token.0)) }.to_bytes()
+        unsafe {
+            CStr::from_ptr(llama_token_get_text(
+                **self.model.try_read().unwrap(),
+                token.0,
+            ))
+        }
+        .to_bytes()
     }
 
     /// Creates a new evaluation context for this model.
@@ -384,7 +396,7 @@ impl LlamaModel {
         let ctx = unsafe {
             // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
            // for at least the lifetime of `LlamaContext`.
-            llama_new_context_with_model(**self.model.blocking_read(), params)
+            llama_new_context_with_model(**self.model.try_read().unwrap(), params)
         };
 
         let cpus = num_cpus::get() as u32;
@@ -396,13 +408,14 @@ impl LlamaModel {
         }
 
         LlamaSession {
-            model: self.clone(),
-            inner: Arc::new(Mutex::new(LlamaContextInner { ptr: ctx })),
-            history_size: 0,
+            inner: Arc::new(LlamaSessionInner {
+                model: self.clone(),
+                ctx: Mutex::new(LlamaContextInner { ptr: ctx }),
+                history_size: AtomicUsize::new(0),
+            }),
         }
     }
 
-
     /// Returns the beginning of sentence (BOS) token for this context.
     pub fn bos(&self) -> Token {
         self.bos_token
@@ -448,6 +461,7 @@ struct LlamaContextInner {
 }
 
 unsafe impl Send for LlamaContextInner {}
+
 unsafe impl Sync for LlamaContextInner {}
 
 impl Drop for LlamaContextInner {
@@ -464,15 +478,21 @@ impl Drop for LlamaContextInner {
 ///
 /// This stores a small amount of state, which is destroyed when the session is dropped.
 /// You can create an arbitrary number of sessions for a model using [`LlamaModel::create_session`].
+#[derive(Clone)]
 pub struct LlamaSession {
+    inner: Arc<LlamaSessionInner>,
+}
+
+/// The cloned part of a [`LlamaSession`].
+struct LlamaSessionInner {
     /// The model this session was created from.
     model: LlamaModel,
 
     /// A pointer to the llama.cpp side of the model context.
-    inner: Arc<Mutex<LlamaContextInner>>,
+    ctx: Mutex<LlamaContextInner>,
 
     /// The number of tokens present in this model's context.
-    history_size: usize,
+    history_size: AtomicUsize,
 }
 
 /// An error raised while advancing the context in a [`LlamaSession`].
@@ -508,7 +528,7 @@ impl LlamaSession {
     ///
     /// The model will generate new tokens from the end of the context.
     pub fn advance_context_with_tokens(
-        &mut self,
+        &self,
         tokens: impl AsRef<[Token]>,
     ) -> Result<(), LlamaContextError> {
         let tokens = tokens.as_ref();
@@ -562,7 +582,7 @@ impl LlamaSession {
         if unsafe {
             // SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
             // initialized above.
-            llama_decode(self.inner.blocking_lock().ptr, batch)
+            llama_decode(self.inner.ctx.blocking_lock().ptr, batch)
         } != 0
         {
             return Err(LlamaInternalError.into());
@@ -577,40 +597,78 @@ impl LlamaSession {
             llama_batch_free(batch)
         };
 
-        self.history_size += tokens.len();
+        self.inner
+            .history_size
+            .fetch_add(n_tokens, Ordering::SeqCst);
 
         Ok(())
     }
 
+    /// Advances the inner context of this model with `tokens`.
+    ///
+    /// This is a thin `tokio::spawn_blocking` wrapper around
+    /// [`LlamaSession::advance_context_with_tokens`].
+    pub async fn advance_context_with_tokens_async(
+        &mut self,
+        tokens: impl AsRef<[Token]>,
+    ) -> Result<(), LlamaContextError> {
+        let tokens = tokens.as_ref().to_owned();
+        let session = self.clone();
+
+        tokio::task::spawn_blocking(move || session.advance_context_with_tokens(tokens))
+            .await
+            .unwrap()
+    }
+
     /// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
     ///
     /// `ctx` is typically a UTF-8 string, but anything that can be downcast to bytes is accepted.
     pub fn advance_context(&mut self, ctx: impl AsRef<[u8]>) -> Result<(), LlamaContextError> {
-        let tokens = self.model.tokenize_bytes(ctx.as_ref())?.into_boxed_slice();
+        let tokens = self
+            .inner
+            .model
+            .tokenize_bytes(ctx.as_ref())?
+            .into_boxed_slice();
 
         self.advance_context_with_tokens(tokens)
     }
 
+    /// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
+    ///
+    /// This is a thin `tokio::spawn_blocking` wrapper around
+    /// [`LlamaSession::advance_context`].
+    pub async fn advance_context_async(
+        &self,
+        ctx: impl AsRef<[u8]>,
+    ) -> Result<(), LlamaContextError> {
+        let ctx = ctx.as_ref().to_owned();
+        let session = self.clone();
+
+        tokio::task::spawn_blocking(move || {
+            let tokens = session.inner.model.tokenize_bytes(ctx)?.into_boxed_slice();
+
+            session.advance_context_with_tokens(tokens)
+        })
+        .await
+        .unwrap()
+    }
+
     /// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
     /// This is where you want to be if you just want some completions.
     pub fn start_completing(&mut self) -> CompletionHandle {
         let (tx, rx) = flume::unbounded();
+        let history_size = self.inner.history_size.load(Ordering::SeqCst);
+        let session = self.clone();
 
-        info!(
-            "Generating completions with {} tokens of history",
-            self.history_size,
-        );
-
-        let past_tokens = self.history_size;
-        let mutex = self.inner.clone();
+        info!("Generating completions with {history_size} tokens of history");
 
         thread::spawn(move || unsafe {
             llama_beam_search(
-                mutex.blocking_lock().ptr,
+                session.inner.ctx.blocking_lock().ptr,
                 Some(detail::llama_beam_search_callback),
                 Box::leak(Box::new(detail::BeamSearchState { tx })) as *mut _ as *mut c_void,
                 1,
-                past_tokens as i32,
+                history_size as i32,
                 32_768,
             );
         });
@@ -620,7 +678,7 @@ impl LlamaSession {
 
     /// Returns the model this session was created from.
     pub fn model(&self) -> LlamaModel {
-        self.model.clone()
+        self.inner.model.clone()
     }
 }
 
@@ -634,9 +692,11 @@ pub struct CompletionToken<'a> {
 }
 
 impl<'a> CompletionToken<'a> {
-    /// Decodes this token, returning the bytes composing it.
-    pub fn as_bytes(&self) -> &[u8] {
-        self.ctx.model.detokenize(self.token)
+    /// Decodes this token, returning the bytes it is composed of.
+    pub fn detokenize(&self) -> TinyVec<[u8; 8]> {
+        let model = self.ctx.model();
+
+        model.detokenize(self.token).into()
     }
 
     /// Returns this token as an `i32`.
@@ -735,7 +795,7 @@ mod detail {
                 // SAFETY: beam_views[i] exists where 0 <= i <= n_beams.
                 *beam_state.beam_views.add(i)
             }
-                .eob = true;
+            .eob = true;
         }
     }
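
Taken together, the new methods let a caller stay on an async runtime from model load through prompt ingestion: `LlamaSession` is now a cheap `Clone` over an `Arc<LlamaSessionInner>`, with `history_size` stored as an `AtomicUsize`, so a clone can be moved into `tokio::task::spawn_blocking` while the original keeps working. The sketch below is a hypothetical usage example assembled from the APIs visible in this diff (`load_from_file_async`, `create_session`, `advance_context_async`, `start_completing`, `CompletionToken::detokenize`); the model path is a placeholder, `create_session`'s exact signature is not shown in this diff, and the completion loop itself still blocks, so a real application might move it into `spawn_blocking` as well:

```rust
use llama_cpp::LlamaModel;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path; point this at a real GGUF model file.
    let model = LlamaModel::load_from_file_async("model.gguf").await?;
    let mut session = model.create_session();

    // New in this commit: feed the prompt without blocking the async executor.
    session.advance_context_async("The sky above the port was").await?;

    let mut completions = session.start_completing();
    let mut decoded_tokens = 0;

    while let Some(next_token) = completions.next_token() {
        // `detokenize` now returns an owned TinyVec<[u8; 8]> instead of a borrowed slice.
        print!("{}", String::from_utf8_lossy(&*next_token.detokenize()));

        decoded_tokens += 1;
        if decoded_tokens > 32 {
            break;
        }
    }

    Ok(())
}
```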
