Skip to content

Commit

Permalink
reducing duplication of *EmbeddingStoreIT
Browse files: browse the repository at this point in the history
  • Loading branch information
deep-learning-dynamo committed Nov 18, 2023
1 parent 9897d65 commit 16f60db
Show file tree
Hide file tree
Showing 21 changed files with 274 additions and 227 deletions.
1 change: 0 additions & 1 deletion langchain4j-bedrock/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
Expand Down
1 change: 0 additions & 1 deletion langchain4j-cassandra/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${parent.version}</version>
</dependency>

<dependency>
Expand Down
9 changes: 8 additions & 1 deletion langchain4j-chroma/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
Expand All @@ -39,6 +38,14 @@
<artifactId>okhttp</artifactId>
</dependency>

<!--
Test-scoped dependency on the "tests" classifier (test-jar) artifact of langchain4j-core.
NOTE(review): presumably this is where the shared AbstractEmbeddingStoreIT base class comes from;
confirm against the langchain4j-core test sources.
-->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
package dev.langchain4j.store.embedding.chroma;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.AbstractEmbeddingStoreIT;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.List;

import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;

@Disabled("needs Chroma running locally")
class ChromaEmbeddingStoreIT {
class ChromaEmbeddingStoreIT extends AbstractEmbeddingStoreIT {

/**
* First ensure you have Chroma running locally. If not, then:
Expand All @@ -36,199 +26,13 @@ class ChromaEmbeddingStoreIT {

private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

/**
 * Adds a single embedding (no explicit id, no segment) and verifies that a
 * similarity search for the same vector returns exactly that entry: a
 * non-null generated id, a score within 1% of 1, an equal embedding and
 * no embedded segment.
 */
@Test
void should_add_embedding() {
    // Embed a freshly generated UUID string.
    final Embedding vector = embeddingModel.embed(randomUUID()).content();

    final String generatedId = embeddingStore.add(vector);
    assertThat(generatedId).isNotNull();

    // The test expects the stored vector to be the only match returned.
    final List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(vector, 10);
    assertThat(matches).hasSize(1);

    final EmbeddingMatch<TextSegment> hit = matches.get(0);
    assertThat(hit.score()).isCloseTo(1, withPercentage(1));
    assertThat(hit.embeddingId()).isEqualTo(generatedId);
    assertThat(hit.embedding()).isEqualTo(vector);
    assertThat(hit.embedded()).isNull();
}

/**
 * Stores an embedding under a caller-chosen id and verifies that a
 * similarity search returns that same id together with an equal embedding,
 * a score within 1% of 1 and no embedded segment.
 */
@Test
void should_add_embedding_with_id() {
    final String chosenId = randomUUID();
    final Embedding vector = embeddingModel.embed(randomUUID()).content();

    embeddingStore.add(chosenId, vector);

    // The test expects the stored vector to be the only match returned.
    final List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(vector, 10);
    assertThat(matches).hasSize(1);

    final EmbeddingMatch<TextSegment> hit = matches.get(0);
    assertThat(hit.score()).isCloseTo(1, withPercentage(1));
    assertThat(hit.embeddingId()).isEqualTo(chosenId);
    assertThat(hit.embedding()).isEqualTo(vector);
    assertThat(hit.embedded()).isNull();
}

/**
 * Stores an embedding together with its text segment and verifies that the
 * search result carries the generated id, an equal embedding, a score
 * within 1% of 1 and a segment equal to the one that was stored.
 */
@Test
void should_add_embedding_with_segment() {
    final TextSegment storedSegment = TextSegment.from(randomUUID());
    // The vector is computed from the segment's own text.
    final Embedding vector = embeddingModel.embed(storedSegment.text()).content();

    final String generatedId = embeddingStore.add(vector, storedSegment);
    assertThat(generatedId).isNotNull();

    final List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(vector, 10);
    assertThat(matches).hasSize(1);

    final EmbeddingMatch<TextSegment> hit = matches.get(0);
    assertThat(hit.score()).isCloseTo(1, withPercentage(1));
    assertThat(hit.embeddingId()).isEqualTo(generatedId);
    assertThat(hit.embedding()).isEqualTo(vector);
    assertThat(hit.embedded()).isEqualTo(storedSegment);
}

/**
 * Same scenario as storing an embedding with a segment, except that the
 * segment carries one metadata entry ("test-key" -> "test-value"). The
 * returned segment must be equal to the stored one.
 */
@Test
void should_add_embedding_with_segment_with_metadata() {
    final Metadata metadata = Metadata.from("test-key", "test-value");
    final TextSegment storedSegment = TextSegment.from(randomUUID(), metadata);
    final Embedding vector = embeddingModel.embed(storedSegment.text()).content();

    final String generatedId = embeddingStore.add(vector, storedSegment);
    assertThat(generatedId).isNotNull();

    final List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(vector, 10);
    assertThat(matches).hasSize(1);

    final EmbeddingMatch<TextSegment> hit = matches.get(0);
    assertThat(hit.score()).isCloseTo(1, withPercentage(1));
    assertThat(hit.embeddingId()).isEqualTo(generatedId);
    assertThat(hit.embedding()).isEqualTo(vector);
    assertThat(hit.embedded()).isEqualTo(storedSegment);
}

@Test
void should_add_multiple_embeddings() {

Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();

List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
assertThat(ids).hasSize(2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);

EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
assertThat(firstMatch.embedded()).isNull();

EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
assertThat(secondMatch.score()).isBetween(0d, 1d);
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
assertThat(secondMatch.embedded()).isNull();
/**
 * Returns the {@code embeddingStore} instance held by this test class.
 * NOTE(review): presumably the hook through which {@code AbstractEmbeddingStoreIT}
 * obtains the store under test; the base class is not visible here, so confirm.
 */
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}

/**
 * Adds two embeddings with their segments in one {@code addAll} call and
 * searches with the first vector. The first result must be the first
 * entry (score within 1% of 1) and the second result the second entry
 * (score between 0 and 1), each with its id, embedding and segment.
 */
@Test
void should_add_multiple_embeddings_with_segments() {
    final TextSegment segmentA = TextSegment.from(randomUUID());
    final Embedding vectorA = embeddingModel.embed(segmentA.text()).content();
    final TextSegment segmentB = TextSegment.from(randomUUID());
    final Embedding vectorB = embeddingModel.embed(segmentB.text()).content();

    final List<Embedding> vectors = asList(vectorA, vectorB);
    final List<TextSegment> segments = asList(segmentA, segmentB);
    final List<String> generatedIds = embeddingStore.addAll(vectors, segments);
    assertThat(generatedIds).hasSize(2);

    // Query with the first vector, so it is expected to rank first.
    final List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(vectorA, 10);
    assertThat(matches).hasSize(2);

    final EmbeddingMatch<TextSegment> best = matches.get(0);
    assertThat(best.score()).isCloseTo(1, withPercentage(1));
    assertThat(best.embeddingId()).isEqualTo(generatedIds.get(0));
    assertThat(best.embedding()).isEqualTo(vectorA);
    assertThat(best.embedded()).isEqualTo(segmentA);

    final EmbeddingMatch<TextSegment> runnerUp = matches.get(1);
    assertThat(runnerUp.score()).isBetween(0d, 1d);
    assertThat(runnerUp.embeddingId()).isEqualTo(generatedIds.get(1));
    assertThat(runnerUp.embedding()).isEqualTo(vectorB);
    assertThat(runnerUp.embedded()).isEqualTo(segmentB);
}

/**
 * Checks the minimum-score filter of {@code findRelevant}. Two embeddings
 * are stored; the score of the second-ranked match is used as a reference
 * threshold. A threshold slightly below it, or exactly equal to it, must
 * keep both matches; a threshold slightly above it must keep only the
 * first.
 */
@Test
void should_find_with_min_score() {
    final String idA = randomUUID();
    final Embedding vectorA = embeddingModel.embed(randomUUID()).content();
    embeddingStore.add(idA, vectorA);

    final String idB = randomUUID();
    final Embedding vectorB = embeddingModel.embed(randomUUID()).content();
    embeddingStore.add(idB, vectorB);

    // Unfiltered search: both entries, the query vector itself ranked first.
    final List<EmbeddingMatch<TextSegment>> unfiltered = embeddingStore.findRelevant(vectorA, 10);
    assertThat(unfiltered).hasSize(2);
    final EmbeddingMatch<TextSegment> best = unfiltered.get(0);
    assertThat(best.score()).isCloseTo(1, withPercentage(1));
    assertThat(best.embeddingId()).isEqualTo(idA);
    final EmbeddingMatch<TextSegment> runnerUp = unfiltered.get(1);
    assertThat(runnerUp.score()).isBetween(0d, 1d);
    assertThat(runnerUp.embeddingId()).isEqualTo(idB);

    // Threshold just below the second score: both entries still match.
    final List<EmbeddingMatch<TextSegment>> belowThreshold =
            embeddingStore.findRelevant(vectorA, 10, runnerUp.score() - 0.01);
    assertThat(belowThreshold).hasSize(2);
    assertThat(belowThreshold.get(0).embeddingId()).isEqualTo(idA);
    assertThat(belowThreshold.get(1).embeddingId()).isEqualTo(idB);

    // Threshold equal to the second score: the test expects the bound to be inclusive.
    final List<EmbeddingMatch<TextSegment>> atThreshold =
            embeddingStore.findRelevant(vectorA, 10, runnerUp.score());
    assertThat(atThreshold).hasSize(2);
    assertThat(atThreshold.get(0).embeddingId()).isEqualTo(idA);
    assertThat(atThreshold.get(1).embeddingId()).isEqualTo(idB);

    // Threshold just above the second score: only the first entry remains.
    final List<EmbeddingMatch<TextSegment>> aboveThreshold =
            embeddingStore.findRelevant(vectorA, 10, runnerUp.score() + 0.01);
    assertThat(aboveThreshold).hasSize(1);
    assertThat(aboveThreshold.get(0).embeddingId()).isEqualTo(idA);
}

@Test
void should_return_correct_score() {

Embedding embedding = embeddingModel.embed("hello").content();

String id = embeddingStore.add(embedding);
assertThat(id).isNotNull();

Embedding referenceEmbedding = embeddingModel.embed("hi").content();

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
assertThat(relevant).hasSize(1);

EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
withPercentage(1)
);
/**
 * Returns the {@code embeddingModel} field of this test class (an
 * {@code AllMiniLmL6V2QuantizedEmbeddingModel} instance).
 * NOTE(review): presumably consumed by the shared tests in
 * {@code AbstractEmbeddingStoreIT}; the base class is not visible here, so confirm.
 */
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
}
18 changes: 18 additions & 0 deletions langchain4j-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@

</dependencies>


<!--
Runs the maven-jar-plugin "test-jar" goal so that this module's test classes are also
packaged as a separate test-jar artifact, which other modules can depend on
(with <type>test-jar</type>, test scope).
-->
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

<licenses>
<license>
<name>Apache License, Version 2.0</name>
Expand Down
Loading

0 comments on commit 16f60db

Please sign in to comment.