forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed relevance score calculation (langchain4j#164)
- Loading branch information
1 parent
f2bb6f9
commit b804d03
Showing
8 changed files
with
131 additions
and
56 deletions.
There are no files selected for viewing
62 changes: 62 additions & 0 deletions
62
langchain4j-core/src/main/java/dev/langchain4j/store/embedding/CosineSimilarity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package dev.langchain4j.store.embedding; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
|
||
import static dev.langchain4j.internal.Exceptions.illegalArgument; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; | ||
|
||
public class CosineSimilarity { | ||
|
||
/** | ||
* Calculates cosine similarity between two vectors. | ||
* <p> | ||
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity. | ||
* It produces a value in the range: | ||
* <p> | ||
* -1 indicates vectors are diametrically opposed (opposite directions). | ||
* <p> | ||
* 0 indicates vectors are orthogonal (no directional similarity). | ||
* <p> | ||
* 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude). | ||
* <p> | ||
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are. | ||
* | ||
* @param embeddingA first embedding vector | ||
* @param embeddingB second embedding vector | ||
* @return cosine similarity in the range [-1..1] | ||
*/ | ||
public static double between(Embedding embeddingA, Embedding embeddingB) { | ||
ensureNotNull(embeddingA, "embeddingA"); | ||
ensureNotNull(embeddingB, "embeddingB"); | ||
|
||
float[] vectorA = embeddingA.vector(); | ||
float[] vectorB = embeddingB.vector(); | ||
|
||
if (vectorA.length != vectorB.length) { | ||
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)", | ||
vectorA.length, vectorB.length); | ||
} | ||
|
||
double dotProduct = 0.0; | ||
double normA = 0.0; | ||
double normB = 0.0; | ||
|
||
for (int i = 0; i < vectorA.length; i++) { | ||
dotProduct += vectorA[i] * vectorB[i]; | ||
normA += vectorA[i] * vectorA[i]; | ||
normB += vectorB[i] * vectorB[i]; | ||
} | ||
|
||
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); | ||
} | ||
|
||
/** | ||
* Converts relevance score into cosine similarity. | ||
* | ||
* @param relevanceScore Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant. | ||
* @return Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant. | ||
*/ | ||
public static double fromRelevanceScore(double relevanceScore) { | ||
return relevanceScore * 2 - 1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 0 additions & 43 deletions
43
langchain4j-core/src/main/java/dev/langchain4j/store/embedding/Similarity.java
This file was deleted.
Oops, something went wrong.
26 changes: 26 additions & 0 deletions
26
langchain4j-core/src/test/java/dev/langchain4j/store/embedding/CosineSimilarityTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package dev.langchain4j.store.embedding; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.data.Percentage.withPercentage; | ||
|
||
class CosineSimilarityTest { | ||
|
||
@Test | ||
void should_calculate_cosine_similarity() { | ||
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1}); | ||
Embedding embeddingB = Embedding.from(new float[]{-1, -1, -1}); | ||
|
||
assertThat(CosineSimilarity.between(embeddingA, embeddingA)).isCloseTo(1, withPercentage(1)); | ||
assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(-1, withPercentage(1)); | ||
} | ||
|
||
@Test | ||
void should_convert_relevance_score_into_cosine_similarity() { | ||
assertThat(CosineSimilarity.fromRelevanceScore(0)).isEqualTo(-1); | ||
assertThat(CosineSimilarity.fromRelevanceScore(0.5)).isEqualTo(0); | ||
assertThat(CosineSimilarity.fromRelevanceScore(1)).isEqualTo(1); | ||
} | ||
} |
15 changes: 15 additions & 0 deletions
15
langchain4j-core/src/test/java/dev/langchain4j/store/embedding/RelevanceScoreTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.store.embedding; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
class RelevanceScoreTest { | ||
|
||
@Test | ||
void should_convert_cosine_similarity_into_relevance_score() { | ||
assertThat(RelevanceScore.fromCosineSimilarity(-1)).isEqualTo(0); | ||
assertThat(RelevanceScore.fromCosineSimilarity(0)).isEqualTo(0.5); | ||
assertThat(RelevanceScore.fromCosineSimilarity(1)).isEqualTo(1); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters