
Commit 9e8a559

feat: Added optional arguments for encode and from_pretrained (#3)
1 parent 22cce1b commit 9e8a559

9 files changed: +227 −99 lines

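In short, the commit widens two public signatures on `StaticModel` and adds one method (reconstructed from the `src/model.rs` diff below):

```rust
// Before
pub fn from_pretrained(repo_or_path: &str, subfolder: Option<&str>) -> Result<Self>;
pub fn tokenize(&self, texts: &[String]) -> Vec<Vec<u32>>;

// After
pub fn from_pretrained(
    repo_or_path: &str,
    token: Option<&str>,     // HF Hub token for authenticated downloads
    normalize: Option<bool>, // override config.json's "normalize"
    subfolder: Option<&str>,
) -> Result<Self>;
pub fn tokenize(&self, texts: &[String], max_length: Option<usize>) -> Vec<Vec<u32>>;

// New: encode with explicit batching/parallelism control
pub fn encode_with_args(
    &self,
    texts: &[String],
    show_progress: bool,
    max_length: Option<usize>,
    batch_size: usize,
    use_multiprocessing: bool,
    multiprocessing_threshold: usize,
) -> Vec<Vec<f32>>;
```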

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,6 @@ version = "0.1.0"
 edition = "2021"
 description = "Fast State-of-the-Art Static Embeddings in Rust"
 readme = "README.md"
-license = "MIT"
 license-file = "LICENSE"
 authors = ["Thomas van Dongen <thomas123@live.nl>", "Stéphan Tulkens <stephantul@gmail.com>"]
 homepage = "https://github.com/MinishLab/model2vec-rs"
@@ -21,6 +20,7 @@ clap = { version = "4.0", features = ["derive"] }
 anyhow = "1.0"
 serde_json = "1.0"
 half = "2.0"
+rayon = "1.7"

 [dev-dependencies]
 approx = "0.5"

README.md

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ use model2vec_rust::inference::StaticModel;

 fn main() -> Result<()> {
     // Load a model from the Hugging Face Hub or a local path
-    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None)?;
+    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;

     // Prepare a list of sentences
     let texts = vec![
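To make the new arguments concrete, here is a hedged sketch that exercises all three optional parameters; the token string and the `Some(true)` override are placeholder values for illustration, not part of the commit:

```rust
use model2vec_rust::inference::StaticModel;
use anyhow::Result;

fn main() -> Result<()> {
    let model = StaticModel::from_pretrained(
        "minishlab/potion-base-8M",
        Some("hf_your_token_here"), // placeholder; exported as HF_HUB_TOKEN
        Some(true),                 // override config.json's "normalize" default
        None,                       // no subfolder
    )?;
    let embeddings = model.encode(&["Hello world".to_string()]);
    println!("embedding dim: {}", embeddings[0].len());
    Ok(())
}
```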

src/main.rs

Lines changed: 4 additions & 3 deletions
@@ -46,7 +46,7 @@ fn main() -> Result<()> {
         vec![input]
     };

-    let m = StaticModel::from_pretrained(&model, None)?;
+    let m = StaticModel::from_pretrained(&model, None, None, None)?;
     let embs = m.encode(&texts);

     if let Some(path) = output {
@@ -67,8 +67,9 @@ fn main() -> Result<()> {
         vec![input]
     };

-    let m = StaticModel::from_pretrained(&model, None)?;
-    let ids = m.tokenize(&texts);
+    let m = StaticModel::from_pretrained(&model, None, None, None)?;
+    // Provide default None for max_length to include all tokens
+    let ids = m.tokenize(&texts, None);
     println!("Token ID sequences: {:#?}", ids);
     }
 }
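For comparison, a hedged sketch of the non-default path; the cap of 128 tokens is an arbitrary value for illustration, not part of the commit:

```rust
// Hypothetical variant of the tokenize branch above. Per the new
// tokenize() in src/model.rs, each text is first cut to
// 128 * median_token_length characters, then its token IDs to 128.
let m = StaticModel::from_pretrained(&model, None, None, None)?;
let ids = m.tokenize(&texts, Some(128));
println!("Token ID sequences: {:#?}", ids);
```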

src/model.rs

Lines changed: 124 additions & 60 deletions
@@ -3,8 +3,8 @@ use tokenizers::Tokenizer;
 use safetensors::{SafeTensors, tensor::Dtype};
 use half::f16;
 use ndarray::Array2;
-use std::fs::read;
-use std::path::Path;
+use rayon::prelude::*;
+use std::{fs::read, path::Path, env};
 use anyhow::{Result, Context, anyhow};
 use serde_json::Value;

@@ -13,104 +13,168 @@ pub struct StaticModel {
     tokenizer: Tokenizer,
     embeddings: Array2<f32>,
     normalize: bool,
+    median_token_length: usize,
 }

 impl StaticModel {
     /// Load a Model2Vec model from a local folder or the HF Hub.
-    ///
-    /// # Arguments
-    /// * `repo_or_path` - HF repo ID or local filesystem path
-    /// * `subfolder` - optional subdirectory inside the repo or folder
-    pub fn from_pretrained(repo_or_path: &str, subfolder: Option<&str>) -> Result<Self> {
+    pub fn from_pretrained(
+        repo_or_path: &str,
+        token: Option<&str>,
+        normalize: Option<bool>,
+        subfolder: Option<&str>,
+    ) -> Result<Self> {
+        // If provided, set HF token for authenticated downloads
+        if let Some(tok) = token {
+            env::set_var("HF_HUB_TOKEN", tok);
+        }
+
         // Determine file paths
         let (tok_path, mdl_path, cfg_path) = {
             let base = Path::new(repo_or_path);
             if base.exists() {
-                // Local path
                 let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
                 let t = folder.join("tokenizer.json");
                 let m = folder.join("model.safetensors");
                 let c = folder.join("config.json");
                 if !t.exists() || !m.exists() || !c.exists() {
-                    return Err(anyhow!("Local path {:?} missing tokenizer/model/config files", folder));
+                    return Err(anyhow!("Local path {:?} missing files", folder));
                 }
                 (t, m, c)
             } else {
-                // HF Hub path
-                let api = Api::new().context("Failed to initialize HF Hub API")?;
+                let api = Api::new().context("HF Hub API init failed")?;
                 let repo = api.model(repo_or_path.to_string());
+                // note: token not used with sync Api
                 let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
-                let t = repo.get(&format!("{}tokenizer.json", prefix)).context("Failed to download tokenizer.json")?;
-                let m = repo.get(&format!("{}model.safetensors", prefix)).context("Failed to download model.safetensors")?;
-                let c = repo.get(&format!("{}config.json", prefix)).context("Failed to download config.json")?;
+                let t = repo.get(&format!("{}tokenizer.json", prefix))
+                    .context("Download tokenizer.json failed")?;
+                let m = repo.get(&format!("{}model.safetensors", prefix))
+                    .context("Download model.safetensors failed")?;
+                let c = repo.get(&format!("{}config.json", prefix))
+                    .context("Download config.json failed")?;
                 (t.into(), m.into(), c.into())
             }
         };

         // Load tokenizer
         let tokenizer = Tokenizer::from_file(&tok_path)
-            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
+            .map_err(|e| anyhow!("Tokenizer load error: {}", e))?;
+
+        // Median token length for char-level truncation
+        let mut lengths: Vec<usize> = tokenizer.get_vocab(false)
+            .keys().map(|tk| tk.len()).collect();
+        lengths.sort_unstable();
+        let median_token_length = *lengths.get(lengths.len() / 2).unwrap_or(&1);
+
+        // Read config.json for the default normalize flag
+        let cfg: Value = serde_json::from_slice(&read(&cfg_path)?)
+            .context("Parse config.json failed")?;
+        let config_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
+        let normalize = normalize.unwrap_or(config_norm);

-        // Read safetensors file
-        let bytes = read(&mdl_path).context("Failed to read model.safetensors")?;
-        let safet = SafeTensors::deserialize(&bytes).context("Failed to parse safetensors")?;
-        let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0")).context("Embedding tensor not found")?;
+        // Read safetensors
+        let bytes = read(&mdl_path).context("Read safetensors failed")?;
+        let safet = SafeTensors::deserialize(&bytes).context("Parse safetensors failed")?;
+        let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0"))
+            .context("No 'embeddings' tensor")?;
         let shape = (tensor.shape()[0] as usize, tensor.shape()[1] as usize);
         let raw = tensor.data();
         let dtype = tensor.dtype();

-        // Read config.json for normalization flag
-        let cfg_bytes = read(&cfg_path).context("Failed to read config.json")?;
-        let cfg: Value = serde_json::from_slice(&cfg_bytes).context("Failed to parse config.json")?;
-        let normalize = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
-
-        // Decode raw bytes into Vec<f32> based on dtype
+        // Decode raw data to f32
         let floats: Vec<f32> = match dtype {
             Dtype::F32 => raw.chunks_exact(4)
-                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
-                .collect(),
+                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect(),
             Dtype::F16 => raw.chunks_exact(2)
-                .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32())
-                .collect(),
-            Dtype::I8 => raw.iter()
-                .map(|&b| (b as i8) as f32)
-                .collect(),
-            other => return Err(anyhow!("Unsupported tensor dtype: {:?}", other)),
+                .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32()).collect(),
+            Dtype::I8 => raw.iter().map(|&b| b as i8 as f32).collect(),
+            other => return Err(anyhow!("Unsupported dtype: {:?}", other)),
         };
+        let embeddings = Array2::from_shape_vec(shape, floats)
+            .context("Array shape error")?;

-        // Construct ndarray
-        let embeddings = Array2::from_shape_vec(shape, floats).context("Failed to create embeddings array")?;
-
-        Ok(Self { tokenizer, embeddings, normalize })
+        Ok(Self { tokenizer, embeddings, normalize, median_token_length })
     }

-    /// Tokenize input texts into token ID sequences
-    pub fn tokenize(&self, texts: &[String]) -> Vec<Vec<u32>> {
-        texts.iter().map(|text| {
-            let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed");
-            enc.get_ids().to_vec()
-        }).collect()
-    }
+    /// Tokenize input texts into token ID sequences with optional truncation.
+    pub fn tokenize(&self, texts: &[String], max_length: Option<usize>) -> Vec<Vec<u32>> {
+        let prepared: Vec<String> = texts.iter().map(|t| {
+            if let Some(max) = max_length {
+                t.chars().take(max.saturating_mul(self.median_token_length)).collect()
+            } else {
+                t.clone()
+            }
+        }).collect();
+        let encs = self.tokenizer.encode_batch(prepared, false).expect("Tokenization failed");
+        encs.into_iter().map(|enc| {
+            let mut ids = enc.get_ids().to_vec();
+            if let Some(max) = max_length { ids.truncate(max); }
+            ids
+        }).collect()
+    }

-    /// Encode texts into embeddings via mean-pooling and optional L2-normalization
-    pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
-        texts.iter().map(|text| {
-            let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed");
-            let ids = enc.get_ids();
-            let mut sum = vec![0.0f32; self.embeddings.ncols()];
-            for &id in ids {
-                let row = self.embeddings.row(id as usize);
-                for (i, &v) in row.iter().enumerate() {
-                    sum[i] += v;
-                }
-            }
-            let count = ids.len().max(1) as f32;
-            sum.iter_mut().for_each(|v| *v /= count);
-            if self.normalize {
-                let norm = sum.iter().map(|&x| x * x).sum::<f32>().sqrt().max(1e-12);
-                sum.iter_mut().for_each(|v| *v /= norm);
-            }
-            sum
-        }).collect()
-    }
+    /// Encode texts into embeddings.
+    ///
+    /// # Arguments
+    /// * `texts` - slice of input strings
+    /// * `show_progress` - whether to print batch progress
+    /// * `max_length` - max tokens per text (truncation)
+    /// * `batch_size` - number of texts per batch
+    /// * `use_multiprocessing` - use Rayon parallelism
+    /// * `multiprocessing_threshold` - minimum number of texts to enable parallelism
+    pub fn encode_with_args(
+        &self,
+        texts: &[String],
+        show_progress: bool,
+        max_length: Option<usize>,
+        batch_size: usize,
+        use_multiprocessing: bool,
+        multiprocessing_threshold: usize,
+    ) -> Vec<Vec<f32>> {
+        let total = texts.len();
+        let num_batches = (total + batch_size - 1) / batch_size;
+        let iter = texts.chunks(batch_size);
+
+        if use_multiprocessing && total > multiprocessing_threshold {
+            // Disable the tokenizer's internal parallelism in favor of Rayon
+            env::set_var("TOKENIZERS_PARALLELISM", "false");
+            iter.enumerate()
+                .flat_map(|(b, chunk)| {
+                    if show_progress { eprintln!("Batch {}/{}", b + 1, num_batches); }
+                    self.tokenize(chunk, max_length)
+                        .into_par_iter()
+                        .map(|ids| self.pool_ids(ids))
+                        .collect::<Vec<_>>()
+                })
+                .collect()
+        } else {
+            let mut out = Vec::with_capacity(total);
+            for (b, chunk) in iter.enumerate() {
+                if show_progress { eprintln!("Batch {}/{}", b + 1, num_batches); }
+                for ids in self.tokenize(chunk, max_length) {
+                    out.push(self.pool_ids(ids));
+                }
+            }
+            out
+        }
+    }
+
+    /// Default encode: no progress, max_length=512, batch_size=1024,
+    /// parallelism enabled once the input exceeds 10,000 texts.
+    pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
+        self.encode_with_args(texts, false, Some(512), 1024, true, 10_000)
+    }
+
+    /// Mean-pool one token ID sequence into a single embedding
+    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
+        let mut sum = vec![0.0; self.embeddings.ncols()];
+        for &id in &ids {
+            let row = self.embeddings.row(id as usize);
+            for (i, &v) in row.iter().enumerate() { sum[i] += v; }
+        }
+        let cnt = ids.len().max(1) as f32;
+        sum.iter_mut().for_each(|v| *v /= cnt);
+        if self.normalize {
+            let norm = sum.iter().map(|&x| x * x).sum::<f32>().sqrt().max(1e-12);
+            sum.iter_mut().for_each(|v| *v /= norm);
+        }
+        sum
+    }
 }

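The new `encode_with_args` surface is worth a usage sketch. The repo ID, input count, and argument values below are illustrative; the crate path `model2vec_rs::model` follows the test files in this commit:

```rust
use model2vec_rs::model::StaticModel;

fn main() -> anyhow::Result<()> {
    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;

    // 20,000 inputs exceed the 10,000 threshold, so each batch's pooling
    // runs through Rayon's parallel iterator.
    let texts: Vec<String> = (0..20_000).map(|i| format!("sentence number {}", i)).collect();
    let embeddings = model.encode_with_args(&texts, true, Some(256), 2048, true, 10_000);

    assert_eq!(embeddings.len(), texts.len());
    println!("encoded {} texts, dim = {}", embeddings.len(), embeddings[0].len());
    Ok(())
}
```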
tests/common.rs

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+use model2vec_rs::model::StaticModel;
+
+/// Load the small float32 test model from fixtures
+pub fn load_test_model() -> StaticModel {
+    StaticModel::from_pretrained(
+        "tests/fixtures/test-model-float32",
+        None, // token
+        None, // normalize
+        None, // subfolder
+    )
+    .expect("Failed to load test model")
+}
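A sketch of how a test might consume this shared helper; the module declaration and test body are assumptions for illustration, not part of the commit:

```rust
mod common;

#[test]
fn fixture_model_encodes_one_embedding_per_text() {
    let model = common::load_test_model();
    let embeddings = model.encode(&["hello world".to_string()]);
    // Mean-pooled output: one embedding per input text
    assert_eq!(embeddings.len(), 1);
}
```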

tests/test_encode.rs

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/test_load.rs

Lines changed: 10 additions & 5 deletions
@@ -1,10 +1,15 @@
 use approx::assert_relative_eq;
 use model2vec_rs::model::StaticModel;

-fn encode_hello(path: &str) -> Vec<f32> {
+fn encode_with_model(path: &str) -> Vec<f32> {
     // Helper function to load the model and encode "hello world"
-    let model = StaticModel::from_pretrained(path, None)
-        .expect(&format!("Failed to load model at {}", path));
+    let model = StaticModel::from_pretrained(
+        path,
+        None,
+        None,
+        None,
+    ).expect(&format!("Failed to load model at {}", path));
+
     let out = model.encode(&["hello world".to_string()]);
     assert_eq!(out.len(), 1);
     out.into_iter().next().unwrap()
@@ -14,11 +19,11 @@ fn encode_hello(path: &str) -> Vec<f32> {
 fn quantized_models_match_float32() {
     // Compare quantized models against the float32 model
     let base = "tests/fixtures/test-model-float32";
-    let ref_emb = encode_hello(base);
+    let ref_emb = encode_with_model(base);

     for quant in &["float16", "int8"] {
         let path = format!("tests/fixtures/test-model-{}", quant);
-        let emb = encode_hello(&path);
+        let emb = encode_with_model(&path);

         assert_eq!(emb.len(), ref_emb.len());
