Skip to content

Commit a7deb5d

Browse files
authored
feat: Added encode_single function (#18)
1 parent 6c3b471 commit a7deb5d

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ The compiled binary installed via `cargo install` is significantly faster (often
107107

108108
* **Encode a single sentence:**
109109
```shell
110-
model2vec-rs encode "Hello world" "minishlab/potion-base-8M"
110+
model2vec-rs encode-single "Hello world" "minishlab/potion-base-8M"
111111
```
112112
Embeddings will be printed to the console in JSON format. This command should take less than 0.1s to execute.
113113

src/main.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,21 @@ enum Commands {
2626
#[arg(short, long)]
2727
output: Option<String>,
2828
},
29+
/// Encode a single sentence
30+
EncodeSingle {
31+
/// The sentence to embed
32+
sentence: String,
33+
/// HF repo ID or local dir
34+
model: String,
35+
#[arg(short, long)]
36+
output: Option<String>,
37+
},
2938
}
3039

3140
fn main() -> Result<()> {
3241
let cli = Cli::parse();
3342
match cli.cmd {
43+
// Encode multiple sentences from a file or input string
3444
Commands::Encode { input, model, output } => {
3545
let texts = if Path::new(&input).exists() {
3646
std::fs::read_to_string(&input)?
@@ -52,6 +62,19 @@ fn main() -> Result<()> {
5262
println!("{:?}", embs);
5363
}
5464
}
65+
// Encode a single sentence
66+
Commands::EncodeSingle { sentence, model, output } => {
67+
let m = StaticModel::from_pretrained(&model, None, None, None)?;
68+
let embedding = m.encode_single(&sentence);
69+
70+
if let Some(path) = output {
71+
let file = File::create(path).context("creating output file failed")?;
72+
serde_json::to_writer(BufWriter::new(file), &embedding)
73+
.context("writing JSON failed")?;
74+
} else {
75+
println!("{embedding:#?}");
76+
}
77+
}
5578
}
5679
Ok(())
5780
}

src/model.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ impl StaticModel {
197197
self.encode_with_args(sentences, Some(512), 1024)
198198
}
199199

200+
// / Encode a single sentence into a vector
201+
pub fn encode_single(&self, sentence: &str) -> Vec<f32> {
202+
self.encode(&[sentence.to_string()]).into_iter().next().unwrap_or_default()
203+
}
204+
200205
/// Mean-pool a single token-ID list into a vector
201206
fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
202207
let mut sum = vec![0.0; self.embeddings.ncols()];

tests/test_model.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,35 @@ fn test_encode_empty_sentence() {
7272
assert!(vec.iter().all(|&x| x == 0.0), "All entries should be zero");
7373
}
7474

75+
/// Check that `encode_single` returns a flat, non-empty vector whose length matches
/// the row produced by `encode` for the same sentence.
///
/// Only shapes are compared here; the embedding values themselves are not checked.
#[test]
fn test_encode_single() {
    let model = load_test_model();
    let text = "hello world";

    // Flat embedding from the single-sentence helper
    let flat = model.encode_single(text);

    // The same sentence through the batch API, as a one-element batch
    let batch = model.encode(&[text.to_string()]);

    assert!(
        !flat.is_empty(),
        "encode_single must return a non-empty 1-D vector"
    );
    assert_eq!(
        batch.len(),
        1,
        "encode(&[..]) should wrap the result in a Vec with length 1"
    );
    assert_eq!(
        batch[0].len(),
        flat.len(),
        "inner vector dimensionality should match encode_single output"
    );
}
103+
75104
/// Test override of `normalize` flag in from_pretrained
76105
#[test]
77106
fn test_normalization_flag_override() {

0 commit comments

Comments
 (0)