File tree Expand file tree Collapse file tree 4 files changed +58
-1
lines changed Expand file tree Collapse file tree 4 files changed +58
-1
lines changed Original file line number Diff line number Diff line change @@ -107,7 +107,7 @@ The compiled binary installed via `cargo install` is significantly faster (often
107107
108108* ** Encode a single sentence:**
109109 ``` shell
110- model2vec-rs encode " Hello world" " minishlab/potion-base-8M"
110+ model2vec-rs encode-single " Hello world" " minishlab/potion-base-8M"
111111 ```
112112 Embeddings will be printed to the console in JSON format. This command should take less than 0.1s to execute.
113113
Original file line number Diff line number Diff line change @@ -26,11 +26,21 @@ enum Commands {
2626 #[ arg( short, long) ]
2727 output : Option < String > ,
2828 } ,
29+ /// Encode a single sentence
30+ EncodeSingle {
31+ /// The sentence to embed
32+ sentence : String ,
33+ /// HF repo ID or local dir
34+ model : String ,
35+ #[ arg( short, long) ]
36+ output : Option < String > ,
37+ } ,
2938}
3039
3140fn main ( ) -> Result < ( ) > {
3241 let cli = Cli :: parse ( ) ;
3342 match cli. cmd {
43+ // Encode multiple sentences from a file or input string
3444 Commands :: Encode { input, model, output } => {
3545 let texts = if Path :: new ( & input) . exists ( ) {
3646 std:: fs:: read_to_string ( & input) ?
@@ -52,6 +62,19 @@ fn main() -> Result<()> {
5262 println ! ( "{:?}" , embs) ;
5363 }
5464 }
65+ // Encode a single sentence
66+ Commands :: EncodeSingle { sentence, model, output } => {
67+ let m = StaticModel :: from_pretrained ( & model, None , None , None ) ?;
68+ let embedding = m. encode_single ( & sentence) ;
69+
70+ if let Some ( path) = output {
71+ let file = File :: create ( path) . context ( "creating output file failed" ) ?;
72+ serde_json:: to_writer ( BufWriter :: new ( file) , & embedding)
73+ . context ( "writing JSON failed" ) ?;
74+ } else {
75+ println ! ( "{embedding:#?}" ) ;
76+ }
77+ }
5578 }
5679 Ok ( ( ) )
5780}
Original file line number Diff line number Diff line change @@ -197,6 +197,11 @@ impl StaticModel {
197197 self . encode_with_args ( sentences, Some ( 512 ) , 1024 )
198198 }
199199
200+ // / Encode a single sentence into a vector
201+ pub fn encode_single ( & self , sentence : & str ) -> Vec < f32 > {
202+ self . encode ( & [ sentence. to_string ( ) ] ) . into_iter ( ) . next ( ) . unwrap_or_default ( )
203+ }
204+
200205 /// Mean-pool a single token-ID list into a vector
201206 fn pool_ids ( & self , ids : Vec < u32 > ) -> Vec < f32 > {
202207 let mut sum = vec ! [ 0.0 ; self . embeddings. ncols( ) ] ;
Original file line number Diff line number Diff line change @@ -72,6 +72,35 @@ fn test_encode_empty_sentence() {
7272 assert ! ( vec. iter( ) . all( |& x| x == 0.0 ) , "All entries should be zero" ) ;
7373}
7474
75+ /// Test that encoding a single sentence returns the correct shape
76+ #[ test]
77+ fn test_encode_single ( ) {
78+ let model = load_test_model ( ) ;
79+ let sentence = "hello world" ;
80+
81+ // Single-sentence helper → 1-D
82+ let one_d = model. encode_single ( sentence) ;
83+
84+ // Batch call with a 1-element slice → 2-D wrapper
85+ let two_d = model. encode ( & [ sentence. to_string ( ) ] ) ;
86+
87+ // Shape assertions
88+ assert ! (
89+ !one_d. is_empty( ) ,
90+ "encode_single must return a non-empty 1-D vector"
91+ ) ;
92+ assert_eq ! (
93+ two_d. len( ) ,
94+ 1 ,
95+ "encode(&[..]) should wrap the result in a Vec with length 1"
96+ ) ;
97+ assert_eq ! (
98+ two_d[ 0 ] . len( ) ,
99+ one_d. len( ) ,
100+ "inner vector dimensionality should match encode_single output"
101+ ) ;
102+ }
103+
75104/// Test override of `normalize` flag in from_pretrained
76105#[ test]
77106fn test_normalization_flag_override ( ) {
You can’t perform that action at this time.
0 commit comments