@@ -551,10 +551,12 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
551551pub fn deduplicate_with_mmr (
552552 query_embedding : & [ f32 ] ,
553553 embeddings : & [ & [ f32 ] ] ,
554+ languages : & [ & str ] ,
554555 lambda : f32 ,
555556 k : usize ,
556557) -> Vec < usize > {
557558 let mut idxs = vec ! [ ] ;
559+ let mut lang_counts = HashMap :: new ( ) ;
558560
559561 if embeddings. len ( ) < k {
560562 return ( 0 ..embeddings. len ( ) ) . collect ( ) ;
@@ -576,14 +578,20 @@ pub fn deduplicate_with_mmr(
576578 second_part = cos_sim;
577579 }
578580 }
579- let equation_score = lambda * first_part - ( 1. - lambda) * second_part;
581+ let mut equation_score = lambda * first_part - ( 1. - lambda) * second_part;
582+
583+ // score is MMR + (1/4)^n where n is the number of times a language has been selected
584+ let count = lang_counts. get ( languages[ i] ) . unwrap_or ( & 0 ) ;
585+ equation_score += 0.25_f32 . powi ( * count) ;
586+
580587 if equation_score > best_score {
581588 best_score = equation_score;
582589 idx_to_add = Some ( i) ;
583590 }
584591 }
585592 if let Some ( i) = idx_to_add {
586593 idxs. push ( i) ;
594+ * lang_counts. entry ( languages[ i] ) . or_insert ( 0 ) += 1 ;
587595 }
588596 }
589597 idxs
@@ -601,7 +609,11 @@ pub fn deduplicate_snippets(
601609 . iter ( )
602610 . map ( |s| s. embedding . as_deref ( ) . unwrap ( ) )
603611 . collect :: < Vec < _ > > ( ) ;
604- deduplicate_with_mmr ( & query_embedding, & embeddings, lambda, k)
612+ let languages = all_snippets
613+ . iter ( )
614+ . map ( |s| s. lang . as_ref ( ) )
615+ . collect :: < Vec < _ > > ( ) ;
616+ deduplicate_with_mmr ( & query_embedding, & embeddings, & languages, lambda, k)
605617 } ;
606618
607619 info ! ( "preserved idxs after MMR are {:?}" , idxs) ;
0 commit comments