Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ target/
#.idea/

local
.DS_Store
.DS_Store
data
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ clap = { version = "4.0", features = ["derive"] }
anyhow = "1.0"
serde_json = "1.0"
half = "2.0"
rayon = "1.7"

[dev-dependencies]
approx = "0.5"
serde_json = "1.0"
serde_json = "1.0"
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,8 @@ Make embeddings with custom encode args:
```rust
let embeddings = model.encode_with_args(
&texts, // input texts
false, // show progress
Some(512), // max length
1204, // batch size
true, // use multiprocessing
10_000, // multiprocessing threshold
1024, // batch size
);
```

Expand Down
38 changes: 9 additions & 29 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use clap::{Parser, Subcommand};
use anyhow::Result;
use anyhow::{Context, Result};
use std::path::Path;
use std::fs::File;
use std::io::BufWriter;

mod model;
use model::StaticModel;
Expand All @@ -24,13 +26,6 @@ enum Commands {
#[arg(short, long)]
output: Option<String>,
},
/// Show token ID sequences for input texts
Tokens {
/// Input text or path to file
input: String,
/// Hugging Face repo ID or local path
model: String,
},
}

fn main() -> Result<()> {
Expand All @@ -48,30 +43,15 @@ fn main() -> Result<()> {

let m = StaticModel::from_pretrained(&model, None, None, None)?;
let embs = m.encode(&texts);

if let Some(path) = output {
let json = serde_json::to_string(&embs)?;
std::fs::write(path, json)?;

if let Some(output) = output {
let file = File::create(&output).context("Failed to create output file")?;
let writer = BufWriter::new(file);
serde_json::to_writer(writer, &embs).context("Failed to write embeddings to JSON")?;
} else {
println!("{:#?}", embs);
println!("Embeddings: {:#?}", embs);
}
}

Commands::Tokens { input, model } => {
let texts = if Path::new(&input).exists() {
std::fs::read_to_string(&input)?
.lines()
.map(str::to_string)
.collect()
} else {
vec![input]
};

let m = StaticModel::from_pretrained(&model, None, None, None)?;
// Provide default None for max_tokens to include all tokens
let ids = m.tokenize(&texts, None);
println!("Token ID sequences: {:#?}", ids);
}
}
Ok(())
}
185 changes: 98 additions & 87 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use tokenizers::Tokenizer;
use safetensors::{SafeTensors, tensor::Dtype};
use half::f16;
use ndarray::Array2;
use rayon::prelude::*;
use std::{fs::read, path::Path, env};
use anyhow::{Result, Context, anyhow};
use std::{env, fs, path::Path};
use anyhow::{Context, Result, anyhow};
use serde_json::Value;

/// Static embedding model for Model2Vec
Expand All @@ -29,7 +28,7 @@ impl StaticModel {
env::set_var("HF_HUB_TOKEN", tok);
}

// Determine file paths
// Locate tokenizer.json, model.safetensors, config.json
let (tok_path, mdl_path, cfg_path) = {
let base = Path::new(repo_or_path);
if base.exists() {
Expand All @@ -38,143 +37,155 @@ impl StaticModel {
let m = folder.join("model.safetensors");
let c = folder.join("config.json");
if !t.exists() || !m.exists() || !c.exists() {
return Err(anyhow!("Local path {:?} missing files", folder));
return Err(anyhow!("Local path {:?} missing tokenizer/model/config", folder));
}
(t, m, c)
} else {
let api = Api::new().context("HF Hub API init failed")?;
let repo = api.model(repo_or_path.to_string());
// note: token not used with sync Api
let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
let t = repo.get(&format!("{}tokenizer.json", prefix))
.context("Download tokenizer.json failed")?;
let m = repo.get(&format!("{}model.safetensors", prefix))
.context("Download model.safetensors failed")?;
let c = repo.get(&format!("{}config.json", prefix))
.context("Download config.json failed")?;
let t = repo
.get(&format!("{}tokenizer.json", prefix))
.context("Failed to download tokenizer.json")?;
let m = repo
.get(&format!("{}model.safetensors", prefix))
.context("Failed to download model.safetensors")?;
let c = repo
.get(&format!("{}config.json", prefix))
.context("Failed to download config.json")?;
(t.into(), m.into(), c.into())
}
};

// Load tokenizer
// Load the tokenizer
let tokenizer = Tokenizer::from_file(&tok_path)
.map_err(|e| anyhow!("Tokenizer load error: {}", e))?;

// Median token length for char-level truncation
let mut lengths: Vec<usize> = tokenizer.get_vocab(false)
.keys().map(|tk| tk.len()).collect();
lengths.sort_unstable();
let median_token_length = *lengths.get(lengths.len() / 2).unwrap_or(&1);

// Read config.json for default normalize
let cfg: Value = serde_json::from_slice(&read(&cfg_path)?)
.context("Parse config.json failed")?;
let config_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(config_norm);

// Read safetensors
let bytes = read(&mdl_path).context("Read safetensors failed")?;
let safet = SafeTensors::deserialize(&bytes).context("Parse safetensors failed")?;
let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0"))
.context("No 'embeddings' tensor")?;
let shape = (tensor.shape()[0] as usize, tensor.shape()[1] as usize);
let raw = tensor.data();
.map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;

// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer
.get_vocab(false)
.keys()
.map(|tk| tk.len())
.collect();
lens.sort_unstable();
let median_token_length = *lens.get(lens.len() / 2).unwrap_or(&1);

// Read normalize default from config.json
let cfg_bytes = fs::read(&cfg_path).context("Failed to read config.json")?;
let cfg: Value = serde_json::from_slice(&cfg_bytes).context("Failed to parse config.json")?;
let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(cfg_norm);

// Load the safetensors
let model_bytes = fs::read(&mdl_path).context("Failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&model_bytes).context("Failed to parse safetensors")?;
let tensor = safet
.tensor("embeddings")
.or_else(|_| safet.tensor("0"))
.context("No 'embeddings' tensor found")?;
let (rows, cols) = (tensor.shape()[0] as usize, tensor.shape()[1] as usize);
let raw = tensor.data();
let dtype = tensor.dtype();

// Decode raw data to f32
// Decode into f32
let floats: Vec<f32> = match dtype {
Dtype::F32 => raw.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0],b[1],b[2],b[3]])).collect(),
Dtype::F16 => raw.chunks_exact(2)
.map(|b| f16::from_le_bytes([b[0],b[1]]).to_f32()).collect(),
Dtype::I8 => raw.iter().map(|&b| b as i8 as f32).collect(),
other => return Err(anyhow!("Unsupported dtype: {:?}", other)),
other => return Err(anyhow!("Unsupported tensor dtype: {:?}", other)),
};
let embeddings = Array2::from_shape_vec(shape, floats)
.context("Array shape error")?;
let embeddings = Array2::from_shape_vec((rows, cols), floats)
.context("Failed to build embeddings array")?;

Ok(Self { tokenizer, embeddings, normalize, median_token_length })
Ok(Self {
tokenizer,
embeddings,
normalize,
median_token_length,
})
}

/// Tokenize input texts into token ID sequences with optional truncation.
pub fn tokenize(&self, texts: &[String], max_length: Option<usize>) -> Vec<Vec<u32>> {
let prepared: Vec<String> = texts.iter().map(|t| {
if let Some(max) = max_length {
t.chars().take(max.saturating_mul(self.median_token_length)).collect()
} else { t.clone() }
}).collect();
let encs = self.tokenizer.encode_batch(prepared, false).expect("Tokenization failed");
encs.into_iter().map(|enc| {
let mut ids = enc.get_ids().to_vec(); if let Some(max) = max_length { ids.truncate(max); } ids
}).collect()
/// Truncates `s` to at most `max_tokens * median_len` characters.
///
/// This is a cheap character-level pre-truncation applied before tokenization.
/// The cut always lands on a `char` boundary, so the returned slice is valid
/// UTF-8 and borrows from `s` without allocating.
///
/// # Arguments
/// * `s` - input text
/// * `max_tokens` - token budget for the text
/// * `median_len` - per-token length used to turn the token budget into a
///   character budget
///
/// # Returns
/// The whole of `s` when it contains at most `max_tokens * median_len`
/// characters, otherwise the prefix holding exactly that many characters.
/// The budget is computed with a saturating multiply, so very large limits
/// leave the input untouched instead of overflowing.
fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
    let max_chars = max_tokens.saturating_mul(median_len);
    // `nth(max_chars)` yields the first character past the budget, if there is
    // one; its byte offset is exactly where the kept prefix ends. Scanning
    // stops after `max_chars + 1` characters, so a separate full-length
    // `chars().count()` pass over the input is not needed.
    match s.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => &s[..byte_idx],
        None => s,
    }
}

/// Encode texts into embeddings.
///
/// # Arguments
/// * `texts` - slice of input strings
/// * `show_progress` - whether to print batch progress
/// * `max_length` - max tokens per text (truncation)
/// * `batch_size` - number of texts per batch
/// * `use_parallel` - use Rayon parallelism
/// * `parallel_threshold` - minimum texts to enable parallelism
pub fn encode_with_args(
&self,
texts: &[String],
show_progress: bool,
max_length: Option<usize>,
batch_size: usize,
use_multiprocessing: bool,
multiprocessing_threshold: usize,
) -> Vec<Vec<f32>> {
let total = texts.len();
let num_batches = (total + batch_size - 1) / batch_size;
let iter = texts.chunks(batch_size);

if use_multiprocessing && total > multiprocessing_threshold {
// disable tokenizer internal parallel
env::set_var("TOKENIZERS_PARALLELISM", "false");
iter
.enumerate()
.flat_map(|(b, chunk)| {
if show_progress { eprintln!("Batch {}/{}", b+1, num_batches); }
self.tokenize(chunk, max_length)
.into_par_iter()
.map(|ids| self.pool_ids(ids))
.collect::<Vec<_>>()
let mut out = Vec::with_capacity(texts.len());

for chunk in texts.chunks(batch_size) {
// Truncate the input strings to max_length * median_token_length
let slices: Vec<&str> = chunk.iter()
.map(|t| {
if let Some(mx) = max_length {
Self::truncate_str(t, mx, self.median_token_length)
} else {
t.as_str()
}
})
.collect()
} else {
let mut out = Vec::with_capacity(total);
for (b, chunk) in iter.enumerate() {
if show_progress { eprintln!("Batch {}/{}", b+1, num_batches); }
for ids in self.tokenize(chunk, max_length) {
out.push(self.pool_ids(ids));
.collect();

// Tokenize the batch
let encs = self
.tokenizer
.encode_batch(slices, false)
.expect("Tokenization failed");

// Encode the token IDs into embeddings
for enc in encs {
let mut ids = enc.get_ids().to_vec();
if let Some(mx) = max_length {
ids.truncate(mx);
}
out.push(self.pool_ids(ids));
}
out
}

out
}

/// Default encode: no progress, max_length=512, batch_size=1024, no parallel.
/// Default encode: `max_length=512`, `batch_size=1024`
pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
self.encode_with_args(texts, false, Some(512), 1024, true, 10_000)
self.encode_with_args(texts, Some(512), 1024)
}

/// Mean-pool one ID list to embedding
/// Mean-pool a single token-ID list into a vector
fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
let mut sum = vec![0.0; self.embeddings.ncols()];
for &id in &ids {
let row = self.embeddings.row(id as usize);
for (i, &v) in row.iter().enumerate() { sum[i] += v; }
for (i, &v) in row.iter().enumerate() {
sum[i] += v;
}
}
let cnt = ids.len().max(1) as f32;
sum.iter_mut().for_each(|v| *v /= cnt);
sum.iter_mut().for_each(|x| *x /= cnt);
if self.normalize {
let norm = sum.iter().map(|&x| x*x).sum::<f32>().sqrt().max(1e-12);
sum.iter_mut().for_each(|v| *v /= norm);
let norm = sum.iter().map(|&v| v*v).sum::<f32>().sqrt().max(1e-12);
sum.iter_mut().for_each(|x| *x /= norm);
}
sum
}
}

1 change: 1 addition & 0 deletions tests/fixtures/embeddings_long.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[-0.7980846762657166, -0.6025453805923462]]
File renamed without changes.
Loading