Skip to content

Commit

Permalink
Implement remove methods for InMemoryEmbeddingStore (langchain4j#1220)
Browse files Browse the repository at this point in the history
## Issue
[https://github.com/langchain4j/langchain4j/issues/301](https://github.com/langchain4j/langchain4j/issues/301)

## Change
I've implemented remove methods for InMemoryEmbeddingStore.

Current problems:
1. I'm not completely sure how to test all of this. And, as far as I can
see, you don't fully test it either.
2. There is a very nasty problem with `removeAll(Filter filter)`. Let
me try to break it down into small pieces:
1. `EmbeddingStore`s are parameterized with the `Embedded` type parameter.
2. The `removeAll(Filter filter)` method in `EmbeddingStore` accepts a
`Filter`.
   3. A `Filter` accepts a `Metadata` object.
   4. The `Metadata` object is stored in a `TextSegment`.
5. `TextSegment` is usually what is `Embedded`, but that's not guaranteed.

How my pull request deals with this problem: 
1) for every entry in `InMemoryEmbeddingStore` check the type of
`Embedded`.
2) If it's a `TextSegment`, then proceed.
3) If it's not, then raise `UnsupportedOperationException`.
The issue with this strategy is that we perform the check for every
entry. I haven't found a way in Java to check the type of a type parameter
(because it's erased at runtime). (*In C++, for example, there are
`constexpr if`, `<type_traits>`, `std::is_same<U, V>`, explicit
specialization, etc.*)

We could solve this problem if `InMemoryEmbeddingStore` supported
only `TextSegment` as `Embedded`, but that's a breaking change.

I could also leave this method unimplemented, but I really need it in
my project (if we choose to stick with `InMemoryEmbeddingStore`).
  

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [ ] I have added unit and integration tests for my change
- [ ] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
  • Loading branch information
InAnYan authored Jun 13, 2024
1 parent 964ba81 commit 2b2c09f
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 97 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package dev.langchain4j.store.embedding;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.filter.Filter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
 * Reusable test suite for the removal operations of an embedding store:
 * {@code remove(id)}, {@code removeAll(ids)}, {@code removeAll(filter)} and {@code removeAll()}.
 * <p>
 * Concrete subclasses supply the store and the embedding model under test through
 * {@code embeddingStore()} and {@code embeddingModel()}.
 */
public abstract class EmbeddingStoreWithRemovalIT extends EmbeddingStoreIT {

    /**
     * Empties the store before every test so that each test only sees the entries it added itself.
     */
    @BeforeEach
    void beforeEach() {
        embeddingStore().removeAll();
    }

    /**
     * Removing a single id must delete exactly that entry and keep the other two.
     */
    @Test
    void remove_by_id() {
        Embedding embedding = embeddingModel().embed("hello").content();
        Embedding embedding2 = embeddingModel().embed("hello2").content();
        Embedding embedding3 = embeddingModel().embed("hello3").content();

        String id = embeddingStore().add(embedding);
        String id2 = embeddingStore().add(embedding2);
        String id3 = embeddingStore().add(embedding3);

        assertThat(id).isNotBlank();
        assertThat(id2).isNotBlank();
        assertThat(id3).isNotBlank();

        List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding);
        assertThat(relevant).hasSize(3);

        embeddingStore().remove(id);

        // The order of search results is decided by the store (typically by relevance to the
        // query), not by insertion order, so the surviving ids are compared ignoring order.
        assertThat(idsOf(findRelevant(embedding))).containsExactlyInAnyOrder(id2, id3);
    }

    /**
     * Removing a collection of ids must delete all of them and keep the remaining entry.
     */
    @Test
    void remove_all_by_ids() {
        Embedding embedding = embeddingModel().embed("hello").content();
        Embedding embedding2 = embeddingModel().embed("hello2").content();
        Embedding embedding3 = embeddingModel().embed("hello3").content();

        String id = embeddingStore().add(embedding);
        String id2 = embeddingStore().add(embedding2);
        String id3 = embeddingStore().add(embedding3);

        embeddingStore().removeAll(Arrays.asList(id2, id3));

        assertThat(idsOf(findRelevant(embedding))).containsExactly(id);
    }

    /**
     * A {@code null} id collection must be rejected with an {@link IllegalArgumentException}.
     */
    @Test
    void remove_all_by_ids_null() {
        assertThatThrownBy(() -> embeddingStore().removeAll((Collection<String>) null))
                .isExactlyInstanceOf(IllegalArgumentException.class)
                .hasMessage("ids cannot be null or empty");
    }

    /**
     * Removing by metadata filter must delete only the segment whose metadata matches;
     * entries added without a segment must survive.
     */
    @Test
    void remove_all_by_filter() {
        Metadata metadata = Metadata.metadata("id", "1");
        TextSegment segment = TextSegment.from("matching", metadata);
        Embedding embedding = embeddingModel().embed(segment).content();
        embeddingStore().add(embedding, segment);

        Embedding embedding2 = embeddingModel().embed("hello2").content();
        Embedding embedding3 = embeddingModel().embed("hello3").content();

        String id2 = embeddingStore().add(embedding2);
        String id3 = embeddingStore().add(embedding3);

        embeddingStore().removeAll(metadataKey("id").isEqualTo("1"));

        // Result order comes from the store's ranking, so compare ignoring order.
        assertThat(idsOf(findRelevant(embedding))).containsExactlyInAnyOrder(id2, id3);
    }

    /**
     * A filter that matches no entry must leave the store untouched.
     */
    @Test
    void remove_all_by_filter_not_matching() {
        Embedding embedding = embeddingModel().embed("hello").content();
        Embedding embedding2 = embeddingModel().embed("hello2").content();
        Embedding embedding3 = embeddingModel().embed("hello3").content();

        embeddingStore().add(embedding);
        embeddingStore().add(embedding2);
        embeddingStore().add(embedding3);

        embeddingStore().removeAll(metadataKey("unknown").isEqualTo("1"));

        assertThat(idsOf(findRelevant(embedding))).hasSize(3);
    }

    /**
     * A {@code null} filter must be rejected with an {@link IllegalArgumentException}.
     */
    @Test
    void remove_all_by_filter_null() {
        assertThatThrownBy(() -> embeddingStore().removeAll((Filter) null))
                .isExactlyInstanceOf(IllegalArgumentException.class)
                .hasMessage("filter cannot be null");
    }

    /**
     * Searches the store for up to 10 matches for the given query embedding.
     *
     * @param embedding the query embedding
     * @return the matches returned by the store's {@code search} call
     */
    List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding) {
        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
                .builder()
                .queryEmbedding(embedding)
                .maxResults(10)
                .build();

        return embeddingStore().search(embeddingSearchRequest).matches();
    }

    /**
     * Extracts the embedding ids from a list of matches, preserving the list order.
     */
    private static List<String> idsOf(List<EmbeddingMatch<TextSegment>> matches) {
        return matches.stream()
                .map(EmbeddingMatch::embeddingId)
                .collect(Collectors.toList());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import dev.langchain4j.store.embedding.filter.Filter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -24,7 +25,7 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@Testcontainers
public class PgVectorEmbeddingStoreRemoveIT {
public class PgVectorEmbeddingStoreRemoveIT extends EmbeddingStoreWithRemovalIT {

@Container
static PostgreSQLContainer<?> pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15");
Expand All @@ -41,103 +42,13 @@ public class PgVectorEmbeddingStoreRemoveIT {

EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

@BeforeEach
void beforeEach() {
embeddingStore.removeAll();
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}

@Test
void remove_by_id() {
Embedding embedding = embeddingModel.embed("hello").content();
Embedding embedding2 = embeddingModel.embed("hello2").content();
Embedding embedding3 = embeddingModel.embed("hello3").content();

String id = embeddingStore.add(embedding);
String id2 = embeddingStore.add(embedding2);
String id3 = embeddingStore.add(embedding3);

assertThat(id).isNotBlank();
assertThat(id2).isNotBlank();
assertThat(id3).isNotBlank();

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
assertThat(relevant).hasSize(3);

embeddingStore.remove(id);

relevant = embeddingStore.findRelevant(embedding, 10);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(2);
assertThat(relevantIds).containsExactly(id2, id3);
}

@Test
void remove_all_by_ids() {
Embedding embedding = embeddingModel.embed("hello").content();
Embedding embedding2 = embeddingModel.embed("hello2").content();
Embedding embedding3 = embeddingModel.embed("hello3").content();

String id = embeddingStore.add(embedding);
String id2 = embeddingStore.add(embedding2);
String id3 = embeddingStore.add(embedding3);

embeddingStore.removeAll(Arrays.asList(id2, id3));

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevant).hasSize(1);
assertThat(relevantIds).containsExactly(id);
}

@Test
void remove_all_by_ids_null() {
assertThatThrownBy(() -> embeddingStore.removeAll((Collection<String>) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
}

@Test
void remove_all_by_filter() {
Metadata metadata = Metadata.metadata("id", "1");
TextSegment segment = TextSegment.from("matching", metadata);
Embedding embedding = embeddingModel.embed(segment).content();
embeddingStore.add(embedding, segment);

Embedding embedding2 = embeddingModel.embed("hello2").content();
Embedding embedding3 = embeddingModel.embed("hello3").content();

String id2 = embeddingStore.add(embedding2);
String id3 = embeddingStore.add(embedding3);

embeddingStore.removeAll(metadataKey("id").isEqualTo("1"));

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(2);
assertThat(relevantIds).containsExactly(id2, id3);
}

@Test
void remove_all_by_filter_not_matching() {
Embedding embedding = embeddingModel.embed("hello").content();
Embedding embedding2 = embeddingModel.embed("hello2").content();
Embedding embedding3 = embeddingModel.embed("hello3").content();

embeddingStore.add(embedding);
embeddingStore.add(embedding2);
embeddingStore.add(embedding3);

embeddingStore.removeAll(metadataKey("unknown").isEqualTo("1"));

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList());
assertThat(relevantIds).hasSize(3);
}

@Test
void remove_all_by_filter_null() {
assertThatThrownBy(() -> embeddingStore.removeAll((Filter) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("filter cannot be null");
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.nio.file.StandardOpenOption.CREATE;
Expand Down Expand Up @@ -95,6 +96,38 @@ private List<String> add(List<Entry<Embedded>> newEntries) {
.collect(toList());
}

/**
 * Removes the entry with the given id from this store; does nothing if no entry has that id.
 *
 * @param id the id of the entry to remove; validated with {@code ensureNotBlank}
 * @throws IllegalArgumentException if {@code id} fails the {@code ensureNotBlank} check
 *                                  (null or blank), matching the validation done by the
 *                                  {@code removeAll} overloads
 */
@Override
public void remove(String id) {
    // Validate up front, like the removeAll overloads do, instead of silently ignoring a bad id.
    ensureNotBlank(id, "id");

    entries.removeIf(entry -> entry.id.equals(id));
}

/**
 * Removes every entry whose id is contained in {@code ids}.
 * Ids that do not correspond to any stored entry are ignored.
 *
 * @param ids the ids of the entries to remove; validated with {@code ensureNotEmpty}
 * @throws IllegalArgumentException if {@code ids} is null or empty
 */
@Override
public void removeAll(Collection<String> ids) {
    ensureNotEmpty(ids, "ids");

    // Copy the ids into a hash set once, so that each stored entry is checked with a
    // constant-time lookup regardless of the collection type the caller passed in.
    java.util.Set<String> idsToRemove = new java.util.HashSet<>(ids);

    entries.removeIf(entry -> idsToRemove.contains(entry.id));
}

/**
 * Removes every entry whose embedded {@link TextSegment} has metadata matching {@code filter}.
 * <p>
 * Entries without embedded content are never removed. The filter can only be evaluated
 * against {@link TextSegment} metadata, so any other kind of embedded content is rejected.
 *
 * @param filter the metadata filter selecting the entries to remove; must not be null
 * @throws UnsupportedOperationException if an entry holds embedded content that is not a
 *                                       {@link TextSegment}
 */
@Override
public void removeAll(Filter filter) {
    ensureNotNull(filter, "filter");

    entries.removeIf(candidate -> {
        Object payload = candidate.embedded;

        // Nothing embedded means there is no metadata to test: keep the entry.
        if (payload == null) {
            return false;
        }

        // Metadata lives on TextSegment only; other embedded types cannot be filtered.
        if (!(payload instanceof TextSegment)) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        TextSegment segment = (TextSegment) payload;
        return filter.test(segment.metadata());
    });
}

/**
 * Removes all entries from this store.
 */
@Override
public void removeAll() {
    this.entries.clear();
}

@Override
public EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest embeddingSearchRequest) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.langchain4j.store.embedding.inmemory;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;

/**
 * Runs the removal tests inherited from {@link EmbeddingStoreWithRemovalIT}
 * against an {@link InMemoryEmbeddingStore}.
 */
public class InMemoryEmbeddingStoreRemovalTest extends EmbeddingStoreWithRemovalIT {

    /** Embedding model handed to the inherited tests. */
    EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

    /** Store under test. */
    EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

    /**
     * @return the embedding model used by the inherited tests
     */
    @Override
    protected EmbeddingModel embeddingModel() {
        return this.embeddingModel;
    }

    /**
     * @return the in-memory store the inherited tests operate on
     */
    @Override
    protected EmbeddingStore<TextSegment> embeddingStore() {
        return this.embeddingStore;
    }
}

0 comments on commit 2b2c09f

Please sign in to comment.