Skip to content

Commit

Permalink
Merge pull request #35 from adenyes/master
Browse files Browse the repository at this point in the history
[fix] Compute embedding cosine distance
  • Loading branch information
rellfy authored Jul 15, 2024
2 parents 4201f4f + 4776cce commit 0dcd0e0
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,20 @@ impl Embedding {
Ok(embeddings.data.swap_remove(0))
}

pub fn magnitude(&self) -> f64 {
self.vec.iter().map(|x| x * x).sum::<f64>().sqrt()
}

pub fn distance(&self, other: &Self) -> f64 {
let dot_product: f64 = self
.vec
.iter()
.zip(other.vec.iter())
.map(|(x, y)| x * y)
.sum();
let product_of_lengths = (self.vec.len() * other.vec.len()) as f64;
let product_of_magnitudes = self.magnitude() * other.magnitude();

dot_product / product_of_lengths
1.0 - dot_product / product_of_magnitudes
}
}

Expand Down Expand Up @@ -145,8 +149,7 @@ mod tests {
total_tokens: 0,
},
};

assert_eq!(embeddings.distances()[0], 0.0);
assert_eq!(embeddings.distances()[0], 1.0);
}

#[test]
Expand All @@ -167,6 +170,6 @@ mod tests {
},
};

assert_ne!(embeddings.distances()[0], 0.0);
assert_eq!(embeddings.distances()[0], 0.29289321881345254);
}
}

0 comments on commit 0dcd0e0

Please sign in to comment.