@@ -3,8 +3,8 @@ use tokenizers::Tokenizer;
33use safetensors:: { SafeTensors , tensor:: Dtype } ;
44use half:: f16;
55use ndarray:: Array2 ;
6- use std :: fs :: read ;
7- use std:: path:: Path ;
6+ use rayon :: prelude :: * ;
7+ use std:: { fs :: read , path:: Path , env } ;
88use anyhow:: { Result , Context , anyhow} ;
99use serde_json:: Value ;
1010
@@ -13,104 +13,168 @@ pub struct StaticModel {
1313 tokenizer : Tokenizer ,
1414 embeddings : Array2 < f32 > ,
1515 normalize : bool ,
16+ median_token_length : usize ,
1617}
1718
1819impl StaticModel {
1920 /// Load a Model2Vec model from a local folder or the HF Hub.
20- ///
21- /// # Arguments
22- /// * `repo_or_path` - HF repo ID or local filesystem path
23- /// * `subfolder` - optional subdirectory inside the repo or folder
24- pub fn from_pretrained ( repo_or_path : & str , subfolder : Option < & str > ) -> Result < Self > {
21+ pub fn from_pretrained (
22+ repo_or_path : & str ,
23+ token : Option < & str > ,
24+ normalize : Option < bool > ,
25+ subfolder : Option < & str > ,
26+ ) -> Result < Self > {
27+ // If provided, set HF token for authenticated downloads
28+ if let Some ( tok) = token {
29+ env:: set_var ( "HF_HUB_TOKEN" , tok) ;
30+ }
31+
2532 // Determine file paths
2633 let ( tok_path, mdl_path, cfg_path) = {
2734 let base = Path :: new ( repo_or_path) ;
2835 if base. exists ( ) {
29- // Local path
3036 let folder = subfolder. map ( |s| base. join ( s) ) . unwrap_or_else ( || base. to_path_buf ( ) ) ;
3137 let t = folder. join ( "tokenizer.json" ) ;
3238 let m = folder. join ( "model.safetensors" ) ;
3339 let c = folder. join ( "config.json" ) ;
3440 if !t. exists ( ) || !m. exists ( ) || !c. exists ( ) {
35- return Err ( anyhow ! ( "Local path {:?} missing tokenizer/model/config files" , folder) ) ;
41+ return Err ( anyhow ! ( "Local path {:?} missing files" , folder) ) ;
3642 }
3743 ( t, m, c)
3844 } else {
39- // HF Hub path
40- let api = Api :: new ( ) . context ( "Failed to initialize HF Hub API" ) ?;
45+ let api = Api :: new ( ) . context ( "HF Hub API init failed" ) ?;
4146 let repo = api. model ( repo_or_path. to_string ( ) ) ;
47+ // note: token not used with sync Api
4248 let prefix = subfolder. map ( |s| format ! ( "{}/" , s) ) . unwrap_or_default ( ) ;
43- let t = repo. get ( & format ! ( "{}tokenizer.json" , prefix) ) . context ( "Failed to download tokenizer.json" ) ?;
44- let m = repo. get ( & format ! ( "{}model.safetensors" , prefix) ) . context ( "Failed to download model.safetensors" ) ?;
45- let c = repo. get ( & format ! ( "{}config.json" , prefix) ) . context ( "Failed to download config.json" ) ?;
49+ let t = repo. get ( & format ! ( "{}tokenizer.json" , prefix) )
50+ . context ( "Download tokenizer.json failed" ) ?;
51+ let m = repo. get ( & format ! ( "{}model.safetensors" , prefix) )
52+ . context ( "Download model.safetensors failed" ) ?;
53+ let c = repo. get ( & format ! ( "{}config.json" , prefix) )
54+ . context ( "Download config.json failed" ) ?;
4655 ( t. into ( ) , m. into ( ) , c. into ( ) )
4756 }
4857 } ;
4958
5059 // Load tokenizer
5160 let tokenizer = Tokenizer :: from_file ( & tok_path)
52- . map_err ( |e| anyhow ! ( "Failed to load tokenizer: {}" , e) ) ?;
61+ . map_err ( |e| anyhow ! ( "Tokenizer load error: {}" , e) ) ?;
62+
63+ // Median token length for char-level truncation
64+ let mut lengths: Vec < usize > = tokenizer. get_vocab ( false )
65+ . keys ( ) . map ( |tk| tk. len ( ) ) . collect ( ) ;
66+ lengths. sort_unstable ( ) ;
67+ let median_token_length = * lengths. get ( lengths. len ( ) / 2 ) . unwrap_or ( & 1 ) ;
68+
69+ // Read config.json for default normalize
70+ let cfg: Value = serde_json:: from_slice ( & read ( & cfg_path) ?)
71+ . context ( "Parse config.json failed" ) ?;
72+ let config_norm = cfg. get ( "normalize" ) . and_then ( Value :: as_bool) . unwrap_or ( true ) ;
73+ let normalize = normalize. unwrap_or ( config_norm) ;
5374
54- // Read safetensors file
55- let bytes = read ( & mdl_path) . context ( "Failed to read model.safetensors" ) ?;
56- let safet = SafeTensors :: deserialize ( & bytes) . context ( "Failed to parse safetensors" ) ?;
57- let tensor = safet. tensor ( "embeddings" ) . or_else ( |_| safet. tensor ( "0" ) ) . context ( "Embedding tensor not found" ) ?;
75+ // Read safetensors
76+ let bytes = read ( & mdl_path) . context ( "Read safetensors failed" ) ?;
77+ let safet = SafeTensors :: deserialize ( & bytes) . context ( "Parse safetensors failed" ) ?;
78+ let tensor = safet. tensor ( "embeddings" ) . or_else ( |_| safet. tensor ( "0" ) )
79+ . context ( "No 'embeddings' tensor" ) ?;
5880 let shape = ( tensor. shape ( ) [ 0 ] as usize , tensor. shape ( ) [ 1 ] as usize ) ;
5981 let raw = tensor. data ( ) ;
6082 let dtype = tensor. dtype ( ) ;
6183
62- // Read config.json for normalization flag
63- let cfg_bytes = read ( & cfg_path) . context ( "Failed to read config.json" ) ?;
64- let cfg: Value = serde_json:: from_slice ( & cfg_bytes) . context ( "Failed to parse config.json" ) ?;
65- let normalize = cfg. get ( "normalize" ) . and_then ( Value :: as_bool) . unwrap_or ( true ) ;
66-
67- // Decode raw bytes into Vec<f32> based on dtype
84+ // Decode raw data to f32
6885 let floats: Vec < f32 > = match dtype {
6986 Dtype :: F32 => raw. chunks_exact ( 4 )
70- . map ( |b| f32:: from_le_bytes ( [ b[ 0 ] , b[ 1 ] , b[ 2 ] , b[ 3 ] ] ) )
71- . collect ( ) ,
87+ . map ( |b| f32:: from_le_bytes ( [ b[ 0 ] , b[ 1 ] , b[ 2 ] , b[ 3 ] ] ) ) . collect ( ) ,
7288 Dtype :: F16 => raw. chunks_exact ( 2 )
73- . map ( |b| f16:: from_le_bytes ( [ b[ 0 ] , b[ 1 ] ] ) . to_f32 ( ) )
74- . collect ( ) ,
75- Dtype :: I8 => raw. iter ( )
76- . map ( |& b| ( b as i8 ) as f32 )
77- . collect ( ) ,
78- other => return Err ( anyhow ! ( "Unsupported tensor dtype: {:?}" , other) ) ,
89+ . map ( |b| f16:: from_le_bytes ( [ b[ 0 ] , b[ 1 ] ] ) . to_f32 ( ) ) . collect ( ) ,
90+ Dtype :: I8 => raw. iter ( ) . map ( |& b| b as i8 as f32 ) . collect ( ) ,
91+ other => return Err ( anyhow ! ( "Unsupported dtype: {:?}" , other) ) ,
7992 } ;
93+ let embeddings = Array2 :: from_shape_vec ( shape, floats)
94+ . context ( "Array shape error" ) ?;
8095
81- // Construct ndarray
82- let embeddings = Array2 :: from_shape_vec ( shape, floats) . context ( "Failed to create embeddings array" ) ?;
83-
84- Ok ( Self { tokenizer, embeddings, normalize } )
96+ Ok ( Self { tokenizer, embeddings, normalize, median_token_length } )
8597 }
8698
87- /// Tokenize input texts into token ID sequences
88- pub fn tokenize ( & self , texts : & [ String ] ) -> Vec < Vec < u32 > > {
89- texts. iter ( ) . map ( |text| {
90- let enc = self . tokenizer . encode ( text. as_str ( ) , false ) . expect ( "Tokenization failed" ) ;
91- enc. get_ids ( ) . to_vec ( )
99+ /// Tokenize input texts into token ID sequences with optional truncation.
100+ pub fn tokenize ( & self , texts : & [ String ] , max_length : Option < usize > ) -> Vec < Vec < u32 > > {
101+ let prepared: Vec < String > = texts. iter ( ) . map ( |t| {
102+ if let Some ( max) = max_length {
103+ t. chars ( ) . take ( max. saturating_mul ( self . median_token_length ) ) . collect ( )
104+ } else { t. clone ( ) }
105+ } ) . collect ( ) ;
106+ let encs = self . tokenizer . encode_batch ( prepared, false ) . expect ( "Tokenization failed" ) ;
107+ encs. into_iter ( ) . map ( |enc| {
108+ let mut ids = enc. get_ids ( ) . to_vec ( ) ; if let Some ( max) = max_length { ids. truncate ( max) ; } ids
92109 } ) . collect ( )
93110 }
94111
95- /// Encode texts into embeddings via mean-pooling and optional L2-normalization
96- pub fn encode ( & self , texts : & [ String ] ) -> Vec < Vec < f32 > > {
97- texts. iter ( ) . map ( |text| {
98- let enc = self . tokenizer . encode ( text. as_str ( ) , false ) . expect ( "Tokenization failed" ) ;
99- let ids = enc. get_ids ( ) ;
100- let mut sum = vec ! [ 0.0f32 ; self . embeddings. ncols( ) ] ;
101- for & id in ids {
102- let row = self . embeddings . row ( id as usize ) ;
103- for ( i, & v) in row. iter ( ) . enumerate ( ) {
104- sum[ i] += v;
112+ /// Encode texts into embeddings.
113+ ///
114+ /// # Arguments
115+ /// * `texts` - slice of input strings
116+ /// * `show_progress` - whether to print batch progress
117+ /// * `max_length` - max tokens per text (truncation)
118+ /// * `batch_size` - number of texts per batch
119+ /// * `use_parallel` - use Rayon parallelism
120+ /// * `parallel_threshold` - minimum texts to enable parallelism
121+ pub fn encode_with_args (
122+ & self ,
123+ texts : & [ String ] ,
124+ show_progress : bool ,
125+ max_length : Option < usize > ,
126+ batch_size : usize ,
127+ use_multiprocessing : bool ,
128+ multiprocessing_threshold : usize ,
129+ ) -> Vec < Vec < f32 > > {
130+ let total = texts. len ( ) ;
131+ let num_batches = ( total + batch_size - 1 ) / batch_size;
132+ let iter = texts. chunks ( batch_size) ;
133+
134+ if use_multiprocessing && total > multiprocessing_threshold {
135+ // disable tokenizer internal parallel
136+ env:: set_var ( "TOKENIZERS_PARALLELISM" , "false" ) ;
137+ iter
138+ . enumerate ( )
139+ . flat_map ( |( b, chunk) | {
140+ if show_progress { eprintln ! ( "Batch {}/{}" , b+1 , num_batches) ; }
141+ self . tokenize ( chunk, max_length)
142+ . into_par_iter ( )
143+ . map ( |ids| self . pool_ids ( ids) )
144+ . collect :: < Vec < _ > > ( )
145+ } )
146+ . collect ( )
147+ } else {
148+ let mut out = Vec :: with_capacity ( total) ;
149+ for ( b, chunk) in iter. enumerate ( ) {
150+ if show_progress { eprintln ! ( "Batch {}/{}" , b+1 , num_batches) ; }
151+ for ids in self . tokenize ( chunk, max_length) {
152+ out. push ( self . pool_ids ( ids) ) ;
105153 }
106154 }
107- let count = ids. len ( ) . max ( 1 ) as f32 ;
108- sum. iter_mut ( ) . for_each ( |v| * v /= count) ;
109- if self . normalize {
110- let norm = sum. iter ( ) . map ( |& x| x * x) . sum :: < f32 > ( ) . sqrt ( ) . max ( 1e-12 ) ;
111- sum. iter_mut ( ) . for_each ( |v| * v /= norm) ;
112- }
113- sum
114- } ) . collect ( )
155+ out
156+ }
157+ }
158+
159+ /// Default encode: no progress, max_length=512, batch_size=1024, no parallel.
160+ pub fn encode ( & self , texts : & [ String ] ) -> Vec < Vec < f32 > > {
161+ self . encode_with_args ( texts, false , Some ( 512 ) , 1024 , true , 10_000 )
162+ }
163+
164+ /// Mean-pool one ID list to embedding
165+ fn pool_ids ( & self , ids : Vec < u32 > ) -> Vec < f32 > {
166+ let mut sum = vec ! [ 0.0 ; self . embeddings. ncols( ) ] ;
167+ for & id in & ids {
168+ let row = self . embeddings . row ( id as usize ) ;
169+ for ( i, & v) in row. iter ( ) . enumerate ( ) { sum[ i] += v; }
170+ }
171+ let cnt = ids. len ( ) . max ( 1 ) as f32 ;
172+ sum. iter_mut ( ) . for_each ( |v| * v /= cnt) ;
173+ if self . normalize {
174+ let norm = sum. iter ( ) . map ( |& x| x* x) . sum :: < f32 > ( ) . sqrt ( ) . max ( 1e-12 ) ;
175+ sum. iter_mut ( ) . for_each ( |v| * v /= norm) ;
176+ }
177+ sum
115178 }
116179}
180+
0 commit comments