@@ -30,6 +30,7 @@ pub use schema::{Embedding, Payload};
3030
// Qdrant collection holding the document embeddings.
const COLLECTION_NAME: &str = "documents";
// Score cutoff applied to search results — hits below this are dropped
// (NOTE(review): exact comparison semantics not visible here; confirm at usage site).
const SCORE_THRESHOLD: f32 = 0.3;
// Output dimensionality of the embedding model; also used to size the
// Qdrant vector config (`VectorParams { size: EMBEDDING_DIM as u64, .. }`).
const EMBEDDING_DIM: usize = 384;
3334
3435#[ derive( Error , Debug ) ]
3536pub enum SemanticError {
@@ -149,7 +150,7 @@ fn collection_config() -> CreateCollection {
149150 collection_name : COLLECTION_NAME . to_string ( ) ,
150151 vectors_config : Some ( VectorsConfig {
151152 config : Some ( vectors_config:: Config :: Params ( VectorParams {
152- size : 384 ,
153+ size : EMBEDDING_DIM as u64 ,
153154 distance : Distance :: Cosine . into ( ) ,
154155 ..Default :: default ( )
155156 } ) ) ,
@@ -396,8 +397,9 @@ impl Semantic {
396397 . collect :: < Vec < _ > > ( )
397398 } ) ?;
398399
399- // TODO: Deduplicate with respect to all vectors
400- let target_vector = vectors. first ( ) . unwrap ( ) . clone ( ) ;
400+ // deduplicate with mmr with respect to the mean of query vectors
401+ // TODO: implement a more robust multi-vector deduplication strategy
402+ let target_vector = mean_pool ( vectors) ;
401403 Ok ( deduplicate_snippets ( results, target_vector, limit) )
402404 }
403405
@@ -629,6 +631,19 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
629631 dot ( a, b) / ( norm ( a) * norm ( b) )
630632}
631633
/// Element-wise mean (centroid) of a set of embeddings.
///
/// The output dimensionality is taken from the first embedding rather
/// than the hard-coded `EMBEDDING_DIM`, so the function works for any
/// model width and no longer panics if an embedding is longer than 384.
/// All embeddings are assumed to share one dimensionality — TODO confirm
/// at the call site; trailing components beyond the first embedding's
/// length are ignored.
///
/// An empty input returns an empty vector. (The previous version divided
/// the zero accumulator by `0.0`, producing an all-NaN target vector.)
fn mean_pool(embeddings: Vec<Vec<f32>>) -> Vec<f32> {
    let count = embeddings.len() as f32;
    let dim = embeddings.first().map_or(0, |e| e.len());
    let mut sum = vec![0.0_f32; dim];
    for embedding in &embeddings {
        // zip truncates to `dim`, so no out-of-bounds indexing is possible.
        for (acc, &v) in sum.iter_mut().zip(embedding) {
            *acc += v;
        }
    }
    for v in &mut sum {
        *v /= count;
    }
    sum
}
646+
632647// returns a list of indices to preserve from `snippets`
633648//
634649// query_embedding: the embedding of the query terms
0 commit comments