forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement remove methods for InMemoryEmbeddingStore (langchain4j#1220)
## Issue [https://github.com/langchain4j/langchain4j/issues/301](https://github.com/langchain4j/langchain4j/issues/301) ## Change I've implemented remove methods for InMemoryEmbeddingStore. Current problems: 1. I'm not completely sure how to test all of this. And, as I've seen, you don't fully test it too) 2. There is a very nasty problem with `removeAll(Filter filter)` . Let me try to break it in small pieces: 1. `EmbeddingStore`'s are parametrized with `Embedded` type parameter. 2. `removeAll(Filter filter)` method in `EmbeddingStore` accepts `Filter`. 3. `Filter` accepts `Metadata` object. 4. `Metadata` object is stored in `TextSegment`. 5. `TextSegment` is usually what is `Embedded`. But it's not guaranteed. How my pull request deals with this problem: 1) for every entry in `InMemoryEmbeddingStore` check the type of `Embedded`. 2) If it's a `TextSegment`, then proceed. 3) If it's not, then raise `UnsupportedOperationException`. The issue with this strategy is that we perform the check for every entry. I haven't found a way in Java to check the type of type parameter (because it's erasured in runtime). (*In C++, for example, there are `constexpr if`, `<type_traits>`, `std::is_same<U, V>`, explicit specialization, etc.*) We could solve this problem if `InMemoryEmbeddingStore` would support only `TextSegment` as `Embedded`, but that's a breaking change. I could also not implement this method, but, I really really need it in my project (if we chose to stick with `InMemoryEmbeddingStore`). ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [x] There are no breaking changes - [ ] I have added unit and integration tests for my change - [ ] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- Loading branch information
Showing
4 changed files
with
193 additions
and
97 deletions.
There are no files selected for viewing
129 changes: 129 additions & 0 deletions
129
...ain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package dev.langchain4j.store.embedding; | ||
|
||
import java.util.Arrays; | ||
import java.util.Collection; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
|
||
import dev.langchain4j.data.document.Metadata; | ||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.store.embedding.filter.Filter; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Assertions.assertThatThrownBy; | ||
|
||
public abstract class EmbeddingStoreWithRemovalIT extends EmbeddingStoreIT{ | ||
@BeforeEach | ||
void beforeEach() { | ||
embeddingStore().removeAll(); | ||
} | ||
|
||
@Test | ||
void remove_by_id() { | ||
Embedding embedding = embeddingModel().embed("hello").content(); | ||
Embedding embedding2 = embeddingModel().embed("hello2").content(); | ||
Embedding embedding3 = embeddingModel().embed("hello3").content(); | ||
|
||
String id = embeddingStore().add(embedding); | ||
String id2 = embeddingStore().add(embedding2); | ||
String id3 = embeddingStore().add(embedding3); | ||
|
||
assertThat(id).isNotBlank(); | ||
assertThat(id2).isNotBlank(); | ||
assertThat(id3).isNotBlank(); | ||
|
||
List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding); | ||
assertThat(relevant).hasSize(3); | ||
|
||
embeddingStore().remove(id); | ||
|
||
relevant = findRelevant(embedding); | ||
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); | ||
assertThat(relevantIds).hasSize(2); | ||
assertThat(relevantIds).containsExactly(id2, id3); | ||
} | ||
|
||
@Test | ||
void remove_all_by_ids() { | ||
Embedding embedding = embeddingModel().embed("hello").content(); | ||
Embedding embedding2 = embeddingModel().embed("hello2").content(); | ||
Embedding embedding3 = embeddingModel().embed("hello3").content(); | ||
|
||
String id = embeddingStore().add(embedding); | ||
String id2 = embeddingStore().add(embedding2); | ||
String id3 = embeddingStore().add(embedding3); | ||
|
||
embeddingStore().removeAll(Arrays.asList(id2, id3)); | ||
|
||
List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding); | ||
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); | ||
assertThat(relevant).hasSize(1); | ||
assertThat(relevantIds).containsExactly(id); | ||
} | ||
|
||
@Test | ||
void remove_all_by_ids_null() { | ||
assertThatThrownBy(() -> embeddingStore().removeAll((Collection<String>) null)) | ||
.isExactlyInstanceOf(IllegalArgumentException.class) | ||
.hasMessage("ids cannot be null or empty"); | ||
} | ||
|
||
@Test | ||
void remove_all_by_filter() { | ||
Metadata metadata = Metadata.metadata("id", "1"); | ||
TextSegment segment = TextSegment.from("matching", metadata); | ||
Embedding embedding = embeddingModel().embed(segment).content(); | ||
embeddingStore().add(embedding, segment); | ||
|
||
Embedding embedding2 = embeddingModel().embed("hello2").content(); | ||
Embedding embedding3 = embeddingModel().embed("hello3").content(); | ||
|
||
String id2 = embeddingStore().add(embedding2); | ||
String id3 = embeddingStore().add(embedding3); | ||
|
||
embeddingStore().removeAll(metadataKey("id").isEqualTo("1")); | ||
|
||
List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding); | ||
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); | ||
assertThat(relevantIds).hasSize(2); | ||
assertThat(relevantIds).containsExactly(id2, id3); | ||
} | ||
|
||
@Test | ||
void remove_all_by_filter_not_matching() { | ||
Embedding embedding = embeddingModel().embed("hello").content(); | ||
Embedding embedding2 = embeddingModel().embed("hello2").content(); | ||
Embedding embedding3 = embeddingModel().embed("hello3").content(); | ||
|
||
embeddingStore().add(embedding); | ||
embeddingStore().add(embedding2); | ||
embeddingStore().add(embedding3); | ||
|
||
embeddingStore().removeAll(metadataKey("unknown").isEqualTo("1")); | ||
|
||
List<EmbeddingMatch<TextSegment>> relevant = findRelevant(embedding); | ||
List<String> relevantIds = relevant.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); | ||
assertThat(relevantIds).hasSize(3); | ||
} | ||
|
||
@Test | ||
void remove_all_by_filter_null() { | ||
assertThatThrownBy(() -> embeddingStore().removeAll((Filter) null)) | ||
.isExactlyInstanceOf(IllegalArgumentException.class) | ||
.hasMessage("filter cannot be null"); | ||
} | ||
|
||
List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding) { | ||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest | ||
.builder() | ||
.queryEmbedding(embedding) | ||
.maxResults(10) | ||
.build(); | ||
|
||
return embeddingStore().search(embeddingSearchRequest).matches(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
...test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreRemovalTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package dev.langchain4j.store.embedding.inmemory; | ||
|
||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; | ||
import dev.langchain4j.model.embedding.EmbeddingModel; | ||
import dev.langchain4j.store.embedding.EmbeddingStore; | ||
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; | ||
|
||
public class InMemoryEmbeddingStoreRemovalTest extends EmbeddingStoreWithRemovalIT { | ||
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>(); | ||
|
||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); | ||
|
||
@Override | ||
protected EmbeddingStore<TextSegment> embeddingStore() { | ||
return embeddingStore; | ||
} | ||
|
||
@Override | ||
protected EmbeddingModel embeddingModel() { | ||
return embeddingModel; | ||
} | ||
} |