@@ -30,7 +30,7 @@
 //! let mut completions = ctx.start_completing();
 //!
 //! while let Some(next_token) = completions.next_token() {
-//!     println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
+//!     println!("{}", String::from_utf8_lossy(&*next_token.detokenize()));
 //!
 //!     decoded_tokens += 1;
 //!
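The doc example changes because `as_bytes`, which borrowed from the session, is replaced by `detokenize`, which returns an owned `TinyVec<[u8; 8]>`; the `&*` turns that back into a `&[u8]` through `TinyVec`'s `Deref` impl. A minimal sketch of the coercion, assuming only the tinyvec crate:

    use tinyvec::TinyVec;

    // TinyVec<[u8; 8]> derefs to [u8], so `&*buf` is a plain &[u8]
    // that String::from_utf8_lossy accepts.
    let buf: TinyVec<[u8; 8]> = b"hi".iter().copied().collect();
    let text = String::from_utf8_lossy(&*buf);
    assert_eq!(text, "hi");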
@@ -74,10 +74,13 @@
 //! [llama.cpp]: https://github.com/ggerganov/llama.cpp/
 
 #![warn(missing_docs)]
+
 use std::ffi::{c_void, CStr, CString};
 use std::path::{Path, PathBuf};
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::{ptr, thread};
+use tinyvec::TinyVec;
 use tokio::sync::{Mutex, RwLock};
 
 use ctor::{ctor, dtor};
@@ -184,6 +187,7 @@ pub struct LlamaInternalError;
 struct LlamaModelInner(*mut llama_model);
 
 unsafe impl Send for LlamaModelInner {}
+
 unsafe impl Sync for LlamaModelInner {}
 
 impl Drop for LlamaModelInner {
@@ -297,7 +301,9 @@ impl LlamaModel {
     pub async fn load_from_file_async(file_path: impl AsRef<Path>) -> Result<Self, LlamaLoadError> {
         let path = file_path.as_ref().to_owned();
 
-        tokio::task::spawn_blocking(move || Self::load_from_file(path)).await.unwrap()
+        tokio::task::spawn_blocking(move || Self::load_from_file(path))
+            .await
+            .unwrap()
     }
 
     /// Converts `content` into a vector of tokens that are valid input for this model.
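Since `load_from_file_async` just moves the synchronous loader onto tokio's blocking thread pool, callers simply await it. A minimal usage sketch, assuming the crate is imported as `llama_cpp` and using a placeholder model path (neither is part of this commit):

    use llama_cpp::LlamaModel;

    #[tokio::main]
    async fn main() {
        // "model.gguf" is a placeholder path.
        let model = LlamaModel::load_from_file_async("model.gguf")
            .await
            .expect("failed to load model");
    }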
@@ -364,7 +370,13 @@ impl LlamaModel {
             token.0
         );
 
-        unsafe { CStr::from_ptr(llama_token_get_text(**self.model.try_read().unwrap(), token.0)) }.to_bytes()
+        unsafe {
+            CStr::from_ptr(llama_token_get_text(
+                **self.model.try_read().unwrap(),
+                token.0,
+            ))
+        }
+        .to_bytes()
     }
 
     /// Creates a new evaluation context for this model.
@@ -384,7 +396,7 @@ impl LlamaModel {
         let ctx = unsafe {
             // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
             // for at least the lifetime of `LlamaContext`.
-            llama_new_context_with_model(**self.model.blocking_read(), params)
+            llama_new_context_with_model(**self.model.try_read().unwrap(), params)
         };
 
         let cpus = num_cpus::get() as u32;
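A plausible reason for swapping `blocking_read` for `try_read().unwrap()` (the commit itself doesn't say): tokio's `RwLock::blocking_read` panics when called from inside an async runtime, whereas `try_read` returns an error instead of blocking. A minimal sketch of the difference, using a plain tokio `RwLock`:

    use tokio::sync::RwLock;

    #[tokio::main]
    async fn main() {
        let lock = RwLock::new(42);

        // In async context, lock.blocking_read() would panic;
        // try_read() fails fast if a writer currently holds the lock.
        match lock.try_read() {
            Ok(guard) => println!("value: {}", *guard),
            Err(_) => println!("lock is write-held right now"),
        }
    }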
@@ -396,13 +408,14 @@ impl LlamaModel {
         }
 
         LlamaSession {
-            model: self.clone(),
-            inner: Arc::new(Mutex::new(LlamaContextInner { ptr: ctx })),
-            history_size: 0,
+            inner: Arc::new(LlamaSessionInner {
+                model: self.clone(),
+                ctx: Mutex::new(LlamaContextInner { ptr: ctx }),
+                history_size: AtomicUsize::new(0),
+            }),
         }
     }
 
-
     /// Returns the beginning of sentence (BOS) token for this context.
     pub fn bos(&self) -> Token {
         self.bos_token
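With the session state behind `Arc<LlamaSessionInner>`, `create_session` now returns a cheap, clonable handle. A hedged sketch of what that enables, assuming a loaded `model` and some `tokens` (illustrative names, not from this commit):

    // Cloning only bumps the Arc's reference count; both handles
    // drive the same llama.cpp context.
    let session = model.create_session();
    let other = session.clone();

    other.advance_context_with_tokens(tokens)?;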
@@ -448,6 +461,7 @@ struct LlamaContextInner {
 }
 
 unsafe impl Send for LlamaContextInner {}
+
 unsafe impl Sync for LlamaContextInner {}
 
 impl Drop for LlamaContextInner {
@@ -464,15 +478,21 @@ impl Drop for LlamaContextInner {
 ///
 /// This stores a small amount of state, which is destroyed when the session is dropped.
 /// You can create an arbitrary number of sessions for a model using [`LlamaModel::create_session`].
+#[derive(Clone)]
 pub struct LlamaSession {
+    inner: Arc<LlamaSessionInner>,
+}
+
+/// The shared inner state of a [`LlamaSession`].
+struct LlamaSessionInner {
     /// The model this session was created from.
     model: LlamaModel,
 
     /// A pointer to the llama.cpp side of the model context.
-    inner: Arc<Mutex<LlamaContextInner>>,
+    ctx: Mutex<LlamaContextInner>,
 
     /// The number of tokens present in this model's context.
-    history_size: usize,
+    history_size: AtomicUsize,
 }
 
 /// An error raised while advancing the context in a [`LlamaSession`].
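This split is the standard clonable-handle pattern: one `Arc`-owned inner struct holds the shared state, and interior mutability (a `Mutex` for the raw context, an `AtomicUsize` for the token count) lets methods take `&self`. A generic sketch of the pattern, independent of this crate:

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Clone)]
    struct Handle {
        inner: Arc<Inner>,
    }

    struct Inner {
        // Interior mutability: updated through &self.
        count: AtomicUsize,
    }

    impl Handle {
        fn bump(&self) -> usize {
            self.inner.count.fetch_add(1, Ordering::SeqCst)
        }
    }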
@@ -508,7 +528,7 @@ impl LlamaSession {
     ///
     /// The model will generate new tokens from the end of the context.
     pub fn advance_context_with_tokens(
-        &mut self,
+        &self,
         tokens: impl AsRef<[Token]>,
     ) -> Result<(), LlamaContextError> {
         let tokens = tokens.as_ref();
@@ -562,7 +582,7 @@ impl LlamaSession {
         if unsafe {
             // SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
             // initialized above.
-            llama_decode(self.inner.blocking_lock().ptr, batch)
+            llama_decode(self.inner.ctx.blocking_lock().ptr, batch)
         } != 0
         {
             return Err(LlamaInternalError.into());
@@ -577,40 +597,78 @@ impl LlamaSession {
             llama_batch_free(batch)
         };
 
-        self.history_size += tokens.len();
+        self.inner
+            .history_size
+            .fetch_add(n_tokens, Ordering::SeqCst);
 
         Ok(())
     }
 
+    /// Advances the inner context of this model with `tokens`.
+    ///
+    /// This is a thin `tokio::task::spawn_blocking` wrapper around
+    /// [`LlamaSession::advance_context_with_tokens`].
+    pub async fn advance_context_with_tokens_async(
+        &mut self,
+        tokens: impl AsRef<[Token]>,
+    ) -> Result<(), LlamaContextError> {
+        let tokens = tokens.as_ref().to_owned();
+        let session = self.clone();
+
+        tokio::task::spawn_blocking(move || session.advance_context_with_tokens(tokens))
+            .await
+            .unwrap()
+    }
+
     /// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
     ///
     /// `ctx` is typically a UTF-8 string, but anything that can be downcast to bytes is accepted.
     pub fn advance_context(&mut self, ctx: impl AsRef<[u8]>) -> Result<(), LlamaContextError> {
-        let tokens = self.model.tokenize_bytes(ctx.as_ref())?.into_boxed_slice();
+        let tokens = self
+            .inner
+            .model
+            .tokenize_bytes(ctx.as_ref())?
+            .into_boxed_slice();
 
         self.advance_context_with_tokens(tokens)
     }
 
+    /// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
+    ///
+    /// This is a thin `tokio::task::spawn_blocking` wrapper around
+    /// [`LlamaSession::advance_context`].
+    pub async fn advance_context_async(
+        &self,
+        ctx: impl AsRef<[u8]>,
+    ) -> Result<(), LlamaContextError> {
+        let ctx = ctx.as_ref().to_owned();
+        let session = self.clone();
+
+        tokio::task::spawn_blocking(move || {
+            let tokens = session.inner.model.tokenize_bytes(ctx)?.into_boxed_slice();
+
+            session.advance_context_with_tokens(tokens)
+        })
+        .await
+        .unwrap()
+    }
+
     /// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
     /// This is where you want to be if you just want some completions.
     pub fn start_completing(&mut self) -> CompletionHandle {
         let (tx, rx) = flume::unbounded();
+        let history_size = self.inner.history_size.load(Ordering::SeqCst);
+        let session = self.clone();
 
-        info!(
-            "Generating completions with {} tokens of history",
-            self.history_size,
-        );
-
-        let past_tokens = self.history_size;
-        let mutex = self.inner.clone();
+        info!("Generating completions with {history_size} tokens of history");
 
         thread::spawn(move || unsafe {
            llama_beam_search(
-                mutex.blocking_lock().ptr,
+                session.inner.ctx.blocking_lock().ptr,
                 Some(detail::llama_beam_search_callback),
                 Box::leak(Box::new(detail::BeamSearchState { tx })) as *mut _ as *mut c_void,
                 1,
-                past_tokens as i32,
+                history_size as i32,
                 32_768,
             );
         });
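Taken together, the `Clone` handle and the new `*_async` wrappers let a session be fed from async code while beam search streams tokens from a background thread. A minimal end-to-end sketch, assuming a tokio context and an already-loaded `model` (prompt text is illustrative):

    let mut session = model.create_session();

    // Tokenize and decode the prompt off the async executor.
    session
        .advance_context_async("The capital of France is")
        .await?;

    // Completions arrive over a channel from the beam-search thread.
    let mut completions = session.start_completing();
    while let Some(token) = completions.next_token() {
        print!("{}", String::from_utf8_lossy(&*token.detokenize()));
    }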
@@ -620,7 +678,7 @@ impl LlamaSession {
 
     /// Returns the model this session was created from.
     pub fn model(&self) -> LlamaModel {
-        self.model.clone()
+        self.inner.model.clone()
     }
 }
 
@@ -634,9 +692,11 @@ pub struct CompletionToken<'a> {
 }
 
 impl<'a> CompletionToken<'a> {
-    /// Decodes this token, returning the bytes composing it.
-    pub fn as_bytes(&self) -> &[u8] {
-        self.ctx.model.detokenize(self.token)
+    /// Decodes this token, returning the bytes it is composed of.
+    pub fn detokenize(&self) -> TinyVec<[u8; 8]> {
+        let model = self.ctx.model();
+
+        model.detokenize(self.token).into()
     }
 
     /// Returns this token as an `i32`.
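Returning an owned `TinyVec<[u8; 8]>` rather than `&[u8]` ends the borrow of the session while keeping short tokens allocation-free: up to 8 bytes are stored inline, longer sequences spill to the heap. A small sketch of that behavior using tinyvec directly (the token contents are illustrative):

    use tinyvec::TinyVec;

    let short: TinyVec<[u8; 8]> = b"hi".iter().copied().collect();
    assert!(short.is_inline()); // fits in the inline array

    let long: TinyVec<[u8; 8]> = b"a longer token".iter().copied().collect();
    assert!(long.is_heap()); // spilled to a heap Vec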
@@ -735,7 +795,7 @@ mod detail {
                 // SAFETY: beam_views[i] exists where 0 <= i <= n_beams.
                 *beam_state.beam_views.add(i)
             }
-                .eob = true;
+            .eob = true;
         }
     }
 }