src/main.rs (6 changes: 3 additions & 3 deletions)
@@ -45,11 +45,11 @@ fn main() -> Result<()> {
         let embs = m.encode(&texts);
 
         if let Some(output) = output {
-            let file = File::create(&output).context("Failed to create output file")?;
+            let file = File::create(&output).context("failed to create output file")?;
             let writer = BufWriter::new(file);
-            serde_json::to_writer(writer, &embs).context("Failed to write embeddings to JSON")?;
+            serde_json::to_writer(writer, &embs).context("failed to write embeddings to JSON")?;
         } else {
-            println!("Embeddings: {:#?}", embs);
+            println!("{:?}", embs);
         }
     }
 }
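The else branch now prints the bare Debug representation instead of a labelled pretty-print, while the file-output path keeps writing JSON. As a sketch of the consumer side (hedged: the diff never shows the concrete type of `embs`, so `Vec<Vec<f32>>` is an assumption), the file written above could be read back like this:

use std::fs::File;
use std::io::BufReader;

fn read_embeddings(path: &str) -> anyhow::Result<Vec<Vec<f32>>> {
    // Mirrors the writer above: buffered file I/O plus serde_json, one JSON document.
    let reader = BufReader::new(File::open(path)?);
    Ok(serde_json::from_reader(reader)?)
}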
src/model.rs (104 changes: 50 additions & 54 deletions)
@@ -17,9 +17,15 @@ pub struct StaticModel {
 }
 
 impl StaticModel {
-    /// Load a Model2Vec model from a local folder or the HF Hub.
-    pub fn from_pretrained(
-        repo_or_path: &str,
+    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
+    ///
+    /// # Arguments
+    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
+    /// * `token` - Optional HuggingFace token for authenticated downloads.
+    /// * `normalize` - Optional flag to normalize embeddings (default from config.json).
+    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
+    pub fn from_pretrained<P: AsRef<Path>>(
+        repo_or_path: P,
         token: Option<&str>,
         normalize: Option<bool>,
         subfolder: Option<&str>,
@@ -31,36 +37,29 @@ impl StaticModel {

         // Locate tokenizer.json, model.safetensors, config.json
         let (tok_path, mdl_path, cfg_path) = {
-            let base = Path::new(repo_or_path);
+            let base = repo_or_path.as_ref();
             if base.exists() {
                 let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
                 let t = folder.join("tokenizer.json");
                 let m = folder.join("model.safetensors");
                 let c = folder.join("config.json");
                 if !t.exists() || !m.exists() || !c.exists() {
-                    return Err(anyhow!("Local path {:?} missing tokenizer/model/config", folder));
+                    return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
                 }
                 (t, m, c)
             } else {
-                let api = Api::new().context("HF Hub API init failed")?;
-                let repo = api.model(repo_or_path.to_string());
+                let api = Api::new().context("hf-hub API init failed")?;
+                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
                 let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
-                let t = repo
-                    .get(&format!("{}tokenizer.json", prefix))
-                    .context("Failed to download tokenizer.json")?;
-                let m = repo
-                    .get(&format!("{}model.safetensors", prefix))
-                    .context("Failed to download model.safetensors")?;
-                let c = repo
-                    .get(&format!("{}config.json", prefix))
-                    .context("Failed to download config.json")?;
+                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
+                let m = repo.get(&format!("{prefix}model.safetensors"))?;
+                let c = repo.get(&format!("{prefix}config.json"))?;
                 (t.into(), m.into(), c.into())
             }
         };
 
         // Load the tokenizer
-        let tokenizer = Tokenizer::from_file(&tok_path)
-            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
+        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
         // Median-token-length hack for pre-truncation
         let mut lens: Vec<usize> = tokenizer
@@ -69,71 +68,70 @@ impl StaticModel {
             .map(|tk| tk.len())
             .collect();
         lens.sort_unstable();
-        let median_token_length = *lens.get(lens.len() / 2).unwrap_or(&1);
+        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
 
         // Read normalize default from config.json
-        let cfg_bytes = fs::read(&cfg_path).context("Failed to read config.json")?;
-        let cfg: Value = serde_json::from_slice(&cfg_bytes).context("Failed to parse config.json")?;
+        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
+        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
         // Serialize the tokenizer to JSON, then parse it and get the unk_token
-        let spec_json = tokenizer
-            .to_string(false)
-            .map_err(|e| anyhow!("Failed to serialize tokenizer to JSON: {}", e))?;
-        let spec: Value = serde_json::from_str(&spec_json)
-            .context("Failed to parse tokenizer JSON spec")?;
+        let spec_json = tokenizer.to_string(false).map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+        let spec: Value = serde_json::from_str(&spec_json)?;
         let unk_token = spec
             .get("model")
            .and_then(|m| m.get("unk_token"))
            .and_then(Value::as_str)
            .unwrap_or("[UNK]");
-        let unk_token_id_val = tokenizer
+        let unk_token_id = tokenizer
             .token_to_id(unk_token)
-            .ok_or_else(|| anyhow!("Tokenizer JSON declared unk_token=\"{}\" but it’s not in the vocab", unk_token))?
+            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
             as usize;
-        let unk_token_id = Some(unk_token_id_val);
 
         // Load the safetensors
-        let model_bytes = fs::read(&mdl_path).context("Failed to read model.safetensors")?;
-        let safet = SafeTensors::deserialize(&model_bytes).context("Failed to parse safetensors")?;
+        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
+        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
-            .context("No 'embeddings' tensor found")?;
-        let (rows, cols) = (tensor.shape()[0] as usize, tensor.shape()[1] as usize);
-        let raw = tensor.data();
+            .context("embeddings tensor not found")?;
+
+        let [rows, cols]: [usize; 2] = tensor
+            .shape()
+            .try_into()
+            .context("embedding tensor is not 2-D")?;
+        let raw = tensor.data();
+        let dtype = tensor.dtype();
 
         // Decode into f32
         let floats: Vec<f32> = match dtype {
-            Dtype::F32 => raw.chunks_exact(4)
-                .map(|b| f32::from_le_bytes([b[0],b[1],b[2],b[3]])).collect(),
-            Dtype::F16 => raw.chunks_exact(2)
-                .map(|b| f16::from_le_bytes([b[0],b[1]]).to_f32()).collect(),
-            Dtype::I8 => raw.iter().map(|&b| b as i8 as f32).collect(),
-            other => return Err(anyhow!("Unsupported tensor dtype: {:?}", other)),
+            Dtype::F32 => raw
+                .chunks_exact(4)
+                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+                .collect(),
+            Dtype::F16 => raw
+                .chunks_exact(2)
+                .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
+                .collect(),
+            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
+            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
         };
         let embeddings = Array2::from_shape_vec((rows, cols), floats)
-            .context("Failed to build embeddings array")?;
+            .context("failed to build embeddings array")?;
 
         Ok(Self {
             tokenizer,
             embeddings,
             normalize,
             median_token_length,
-            unk_token_id,
+            unk_token_id: Some(unk_token_id),
         })
     }
 
     /// Char-level truncation to max_tokens * median_token_length
-    fn truncate_str<'a>(s: &'a str, max_tokens: usize, median_len: usize) -> &'a str {
+    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
         let max_chars = max_tokens.saturating_mul(median_len);
         // if <= max_chars characters, return whole string
         if s.chars().count() <= max_chars {
             return s;
         }
         // otherwise find the byte index of the (max_chars)th char and cut there
         match s.char_indices().nth(max_chars) {
             Some((byte_idx, _)) => &s[..byte_idx],
             None => s,
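As an aside on the dtype match above, the byte-level decode is easy to sanity-check in isolation. A minimal sketch of the F16 arm, assuming the same `half` crate and little-endian layout the code relies on:

use half::f16;

// Decode a little-endian f16 buffer into f32s, as the Dtype::F16 arm above does.
fn decode_f16_le(raw: &[u8]) -> Vec<f32> {
    raw.chunks_exact(2)
        .map(|b| f16::from_le_bytes(b.try_into().unwrap()).to_f32())
        .collect()
}

// 1.0_f16 is 0x3C00, stored little-endian as [0x00, 0x3C]:
// assert_eq!(decode_f16_le(&[0x00, 0x3C]), vec![1.0_f32]);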
@@ -160,11 +158,9 @@ impl StaticModel {
         let truncated: Vec<&str> = batch
             .iter()
             .map(|text| {
-                if let Some(max_tok) = max_length {
-                    Self::truncate_str(text, max_tok, self.median_token_length)
-                } else {
-                    text.as_str()
-                }
+                max_length
+                    .map(|max_tok| Self::truncate_str(text, max_tok, self.median_token_length))
+                    .unwrap_or(text.as_str())
             })
             .collect();
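For intuition on why `truncate_str`, used at the call site above, cuts at the nth char rather than the nth byte: slicing at a `char_indices` boundary can never split a multi-byte UTF-8 sequence. A small standalone sketch of the same trick:

// Cut a string to its first 4 characters; 'é' is 2 bytes, so the byte
// index of char 4 is 5, and &s[..5] is still valid UTF-8.
let s = "héllo wörld";
let cut = match s.char_indices().nth(4) {
    Some((byte_idx, _)) => &s[..byte_idx],
    None => s,
};
assert_eq!(cut, "héll");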

@@ -176,7 +172,7 @@ impl StaticModel {
                 truncated.into_iter().map(Into::into).collect(),
                 /* add_special_tokens = */ false,
             )
-            .expect("Tokenization failed");
+            .expect("tokenization failed");
 
         // Pool each token-ID list into a single mean vector
         for encoding in encodings {
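Taken together, the new `P: AsRef<Path>` bound means both repo IDs and real paths now work at the call site. A hedged usage sketch (the `encode` signature is inferred from src/main.rs, and the repo ID is purely illustrative; neither is pinned by this diff):

use std::path::Path;

fn demo() -> anyhow::Result<()> {
    // &str still compiles (str implements AsRef<Path>) and is treated as a
    // HuggingFace repo ID when the path does not exist locally.
    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;

    // &Path and PathBuf now also compile, which the old &str signature rejected.
    let _local = StaticModel::from_pretrained(Path::new("./my-model"), None, None, None);

    let embs = model.encode(&["hello world".to_string()]);
    println!("{:?}", embs);
    Ok(())
}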