Skip to content

Commit

Permalink
cleaned up EmbeddingStore removal functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Jun 13, 2024
1 parent b2094b4 commit 16e88df
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 331 deletions.
42 changes: 21 additions & 21 deletions docs/docs/integrations/embedding-stores/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@ hide_title: false
sidebar_position: 0
---

| Provider | Storing Metadata | Filtering by Metadata | Removing Embeddings | Local | Cloud |
|---------------------------------------------------------------------------------------|------------------|-----------------------|---------------------|-------|-------|
| [In-memory](/integrations/embedding-stores/in-memory) ||| | | |
| [Astra DB](/integrations/embedding-stores/astra-db) || | | | |
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) || | | | |
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) || | | | |
| [Cassandra](/integrations/embedding-stores/cassandra) || | | | |
| [Chroma](/integrations/embedding-stores/chroma) || | | | |
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) ||| | | |
| [Infinispan](/integrations/embedding-stores/infinispan) || | | | |
| [Milvus](/integrations/embedding-stores/milvus) ||| | | |
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) || | | | |
| [Neo4j](/integrations/embedding-stores/neo4j) | | | | | |
| [OpenSearch](/integrations/embedding-stores/opensearch) || | | | |
| [PGVector](/integrations/embedding-stores/pgvector) |||| | |
| [Pinecone](/integrations/embedding-stores/pinecone) | | | | | |
| [Qdrant](/integrations/embedding-stores/qdrant) || | | | |
| [Redis](/integrations/embedding-stores/redis) || | | | |
| [Vearch](/integrations/embedding-stores/vearch) || | | | |
| [Vespa](/integrations/embedding-stores/vespa) | | | | | |
| [Weaviate](/integrations/embedding-stores/weaviate) || | | | |
| Provider | Storing Metadata | Filtering by Metadata | Removing Embeddings |
|---------------------------------------------------------------------------------------|------------------|-----------------------|---------------------|
| [In-memory](/integrations/embedding-stores/in-memory) ||| |
| [Astra DB](/integrations/embedding-stores/astra-db) || | |
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) || | |
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) || | |
| [Cassandra](/integrations/embedding-stores/cassandra) || | |
| [Chroma](/integrations/embedding-stores/chroma) || | |
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) ||| |
| [Infinispan](/integrations/embedding-stores/infinispan) || | |
| [Milvus](/integrations/embedding-stores/milvus) ||| |
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) || | |
| [Neo4j](/integrations/embedding-stores/neo4j) | | | |
| [OpenSearch](/integrations/embedding-stores/opensearch) || | |
| [PGVector](/integrations/embedding-stores/pgvector) ||||
| [Pinecone](/integrations/embedding-stores/pinecone) | | | |
| [Qdrant](/integrations/embedding-stores/qdrant) || | |
| [Redis](/integrations/embedding-stores/redis) || | |
| [Vearch](/integrations/embedding-stores/vearch) || | |
| [Vespa](/integrations/embedding-stores/vespa) | | | |
| [Weaviate](/integrations/embedding-stores/weaviate) || | |
Original file line number Diff line number Diff line change
@@ -1,83 +1,110 @@
package dev.langchain4j.store.embedding;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.filter.Filter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class EmbeddingStoreWithRemovalIT extends EmbeddingStoreIT{
@BeforeEach
void beforeEach() {
embeddingStore().removeAll();
}
public abstract class EmbeddingStoreWithRemovalIT {

protected abstract EmbeddingStore<TextSegment> embeddingStore();

protected abstract EmbeddingModel embeddingModel();

@Test
void remove_by_id() {
Embedding embedding = embeddingModel().embed("hello").content();
Embedding embedding2 = embeddingModel().embed("hello2").content();
Embedding embedding3 = embeddingModel().embed("hello3").content();
void should_remove_by_id() {

// given
Embedding embedding1 = embeddingModel().embed("test1").content();
String id1 = embeddingStore().add(embedding1);

String id = embeddingStore().add(embedding);
Embedding embedding2 = embeddingModel().embed("test2").content();
String id2 = embeddingStore().add(embedding2);
String id3 = embeddingStore().add(embedding3);

assertThat(id).isNotBlank();
assertThat(id2).isNotBlank();
assertThat(id3).isNotBlank();
assertThat(getAllEmbeddings()).hasSize(2);

List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding);
assertThat(relevant).hasSize(3);
// when
embeddingStore().remove(id1);

embeddingStore().remove(id);
// then
List<EmbeddingMatch<TextSegment>> relevant = getAllEmbeddings();
assertThat(relevant).hasSize(1);
assertThat(relevant.get(0).embeddingId()).isEqualTo(id2);
}

relevant = findRelevant(embedding);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(2);
assertThat(relevantIds).containsExactly(id2, id3);
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = " ")
void should_fail_to_remove_by_id(String id) {

assertThatThrownBy(() -> embeddingStore().remove(id))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("id cannot be null or blank");
}

@Test
void remove_all_by_ids() {
Embedding embedding = embeddingModel().embed("hello").content();
Embedding embedding2 = embeddingModel().embed("hello2").content();
Embedding embedding3 = embeddingModel().embed("hello3").content();
void should_remove_all_by_ids() {

// given
Embedding embedding1 = embeddingModel().embed("test1").content();
String id1 = embeddingStore().add(embedding1);

String id = embeddingStore().add(embedding);
Embedding embedding2 = embeddingModel().embed("test2").content();
String id2 = embeddingStore().add(embedding2);

Embedding embedding3 = embeddingModel().embed("test3").content();
String id3 = embeddingStore().add(embedding3);

embeddingStore().removeAll(Arrays.asList(id2, id3));
assertThat(getAllEmbeddings()).hasSize(3);

List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
// when
embeddingStore().removeAll(asList(id1, id2));

// then
List<EmbeddingMatch<TextSegment>> relevant = getAllEmbeddings();
assertThat(relevant).hasSize(1);
assertThat(relevantIds).containsExactly(id);
assertThat(relevant.get(0).embeddingId()).isEqualTo(id3);
}

@Test
void remove_all_by_ids_null() {
void should_fail_to_remove_all_by_ids_null() {

assertThatThrownBy(() -> embeddingStore().removeAll((Collection<String>) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
}

@Test
void remove_all_by_filter() {
void should_fail_to_remove_all_by_ids_empty() {

assertThatThrownBy(() -> embeddingStore().removeAll(emptyList()))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
}

@Test
void should_remove_all_by_filter() {

// TODO
Metadata metadata = Metadata.metadata("id", "1");
TextSegment segment = TextSegment.from("matching", metadata);
Embedding embedding = embeddingModel().embed(segment).content();
embeddingStore().add(embedding, segment);
Embedding embedding1 = embeddingModel().embed(segment).content();
embeddingStore().add(embedding1, segment);

Embedding embedding2 = embeddingModel().embed("hello2").content();
Embedding embedding3 = embeddingModel().embed("hello3").content();
Expand All @@ -87,43 +114,67 @@ void remove_all_by_filter() {

embeddingStore().removeAll(metadataKey("id").isEqualTo("1"));

List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding);
List<EmbeddingMatch<TextSegment>> relevant = getAllEmbeddings();
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(2);
assertThat(relevantIds).containsExactly(id2, id3);
}

@Test
void remove_all_by_filter_not_matching() {
Embedding embedding = embeddingModel().embed("hello").content();
void should_remove_all_by_filter_not_matching() {

// TODO
Embedding embedding1 = embeddingModel().embed("hello").content();
Embedding embedding2 = embeddingModel().embed("hello2").content();
Embedding embedding3 = embeddingModel().embed("hello3").content();

embeddingStore().add(embedding);
embeddingStore().add(embedding1);
embeddingStore().add(embedding2);
embeddingStore().add(embedding3);

embeddingStore().removeAll(metadataKey("unknown").isEqualTo("1"));

List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding);
List<EmbeddingMatch<TextSegment>> relevant = getAllEmbeddings();
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(3);
}

@Test
void remove_all_by_filter_null() {
void should_fail_to_remove_all_by_filter_null() {

assertThatThrownBy(() -> embeddingStore().removeAll((Filter) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("filter cannot be null");
}

List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding) {
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
.builder()
.queryEmbedding(embedding)
.maxResults(10)
@Test
void should_remove_all() {

// given
Embedding embedding1 = embeddingModel().embed("test1").content();
embeddingStore().add(embedding1);

Embedding embedding2 = embeddingModel().embed("test2").content();
embeddingStore().add(embedding2);

assertThat(getAllEmbeddings()).hasSize(2);

// when
embeddingStore().removeAll();

// then
assertThat(getAllEmbeddings()).isEmpty();
}

private List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {

EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddingModel().embed("test").content())
.maxResults(1000)
.build();

return embeddingStore().search(embeddingSearchRequest).matches();
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore().search(embeddingSearchRequest);

return searchResult.matches();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;

import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.*;
Expand Down Expand Up @@ -271,12 +267,6 @@ public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddin
}
}

@Override
public void remove(String id) {
ensureNotBlank(id, "id");
removeByIds(singletonList(id));
}

@Override
public void removeAll(Collection<String> ids) {
ensureNotEmpty(ids, "ids");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,6 @@ public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedde
return ids;
}

@Override
public void remove(String id) {
ensureNotBlank(id, "id");
String sql = String.format("DELETE FROM %s WHERE embedding_id = ?", table);
try (Connection connection = getConnection();
PreparedStatement statement = connection.prepareStatement(sql)) {
statement.setObject(1, UUID.fromString(id));
statement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

@Override
public void removeAll(Collection<String> ids) {
ensureNotEmpty(ids, "ids");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
package dev.langchain4j.store.embedding.pgvector;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import dev.langchain4j.store.embedding.filter.Filter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@Testcontainers
public class PgVectorEmbeddingStoreRemoveIT extends EmbeddingStoreWithRemovalIT {
public class PgVectorEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {

@Container
static PostgreSQLContainer<?> pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15");
Expand All @@ -38,6 +23,7 @@ public class PgVectorEmbeddingStoreRemoveIT extends EmbeddingStoreWithRemovalIT
.database("test")
.table("test")
.dimension(384)
.dropTableFirst(true)
.build();

EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
Expand Down
Loading

0 comments on commit 16e88df

Please sign in to comment.