Skip to content

Commit

Permalink
Fixed relevance score calculation (langchain4j#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j authored Sep 7, 2023
1 parent f2bb6f9 commit b804d03
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;

import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

public class CosineSimilarity {

/**
* Calculates cosine similarity between two vectors.
* <p>
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity.
* It produces a value in the range:
* <p>
* -1 indicates vectors are diametrically opposed (opposite directions).
* <p>
* 0 indicates vectors are orthogonal (no directional similarity).
* <p>
* 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
* <p>
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
*
* @param embeddingA first embedding vector
* @param embeddingB second embedding vector
* @return cosine similarity in the range [-1..1]
*/
public static double between(Embedding embeddingA, Embedding embeddingB) {
ensureNotNull(embeddingA, "embeddingA");
ensureNotNull(embeddingB, "embeddingB");

float[] vectorA = embeddingA.vector();
float[] vectorB = embeddingB.vector();

if (vectorA.length != vectorB.length) {
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)",
vectorA.length, vectorB.length);
}

double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;

for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += vectorA[i] * vectorA[i];
normB += vectorB[i] * vectorB[i];
}

return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}

/**
* Converts relevance score into cosine similarity.
*
* @param relevanceScore Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
* @return Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
*/
public static double fromRelevanceScore(double relevanceScore) {
return relevanceScore * 2 - 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
public class RelevanceScore {

/**
* Calculates the relevance score between two vectors using cosine similarity.
* Converts cosine similarity into relevance score.
*
* @param a first vector
* @param b second vector
* @return score in the range [0, 1], where 0 indicates no relevance and 1 indicates full relevance
* @param cosineSimilarity Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
* @return Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
*/
public static double cosine(float[] a, float[] b) {
double cosineSimilarity = Similarity.cosine(a, b);
return 1 - (1 - cosineSimilarity) / 2;
public static double fromCosineSimilarity(double cosineSimilarity) {
return (cosineSimilarity + 1) / 2;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;

class CosineSimilarityTest {

@Test
void should_calculate_cosine_similarity() {
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1});
Embedding embeddingB = Embedding.from(new float[]{-1, -1, -1});

assertThat(CosineSimilarity.between(embeddingA, embeddingA)).isCloseTo(1, withPercentage(1));
assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(-1, withPercentage(1));
}

@Test
void should_convert_relevance_score_into_cosine_similarity() {
assertThat(CosineSimilarity.fromRelevanceScore(0)).isEqualTo(-1);
assertThat(CosineSimilarity.fromRelevanceScore(0.5)).isEqualTo(0);
assertThat(CosineSimilarity.fromRelevanceScore(1)).isEqualTo(1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.store.embedding;

import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

class RelevanceScoreTest {

@Test
void should_convert_cosine_similarity_into_relevance_score() {
assertThat(RelevanceScore.fromCosineSimilarity(-1)).isEqualTo(0);
assertThat(RelevanceScore.fromCosineSimilarity(0)).isEqualTo(0.5);
assertThat(RelevanceScore.fromCosineSimilarity(1)).isEqualTo(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.protobuf.Value;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
Expand Down Expand Up @@ -176,9 +177,10 @@ private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Vector vector, Embed
.get(METADATA_TEXT_SEGMENT);

Embedding embedding = Embedding.from(vector.getValuesList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);

return new EmbeddingMatch<>(
RelevanceScore.cosine(embedding.vector(), referenceEmbedding.vector()),
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
vector.getId(),
embedding,
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.RelevanceScore;

import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.*;
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -111,7 +118,8 @@ public List<E> classify(String text) {
double meanScore = 0;
double maxScore = 0;
for (Embedding exampleEmbedding : exampleEmbeddings) {
double score = RelevanceScore.cosine(textEmbedding.vector(), exampleEmbedding.vector());
double cosineSimilarity = CosineSimilarity.between(textEmbedding, exampleEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
meanScore += score;
maxScore = Math.max(score, maxScore);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.gson.reflect.TypeToken;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
Expand All @@ -13,7 +14,12 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;

import static dev.langchain4j.internal.Utils.randomUUID;
import static java.nio.file.StandardOpenOption.CREATE;
Expand Down Expand Up @@ -116,7 +122,8 @@ public List<EmbeddingMatch<Embedded>> findRelevant(Embedding referenceEmbedding,
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);

for (Entry<Embedded> entry : entries) {
double score = RelevanceScore.cosine(entry.embedding.vector(), referenceEmbedding.vector());
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= minScore) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
if (matches.size() > maxResults) {
Expand Down

0 comments on commit b804d03

Please sign in to comment.