Skip to content

Commit

Permalink
a way to merge multiple in-memory embedding stores (langchain4j#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
dliubarskyi authored Jul 2, 2024
1 parent b74fb7d commit f601aad
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
import static java.util.Arrays.asList;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;

Expand All @@ -37,7 +38,15 @@
*/
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {

final CopyOnWriteArrayList<Entry<Embedded>> entries = new CopyOnWriteArrayList<>();
final CopyOnWriteArrayList<Entry<Embedded>> entries;

/**
 * Creates a store that initially holds no entries.
 */
public InMemoryEmbeddingStore() {
    entries = new CopyOnWriteArrayList<>();
}

/**
 * Creates a store pre-populated with the given entries.
 * The entries are copied into a fresh list owned by this store, so later changes to the
 * passed collection do not affect this store.
 *
 * @param initialEntries the entries to start with
 */
private InMemoryEmbeddingStore(Collection<Entry<Embedded>> initialEntries) {
    entries = new CopyOnWriteArrayList<>(initialEntries);
}

@Override
public String add(Embedding embedding) {
Expand Down Expand Up @@ -189,6 +198,28 @@ public static InMemoryEmbeddingStore<TextSegment> fromFile(String filePath) {
return fromFile(Paths.get(filePath));
}

/**
 * Merges given {@code InMemoryEmbeddingStore}s into a single, new {@code InMemoryEmbeddingStore}.
 * <p>
 * Entries are gathered store by store, in the iteration order of {@code stores}. The given stores
 * are not modified, and entries added to them afterwards do not show up in the merged store.
 * The entry objects themselves are shared with the source stores rather than duplicated, and
 * no de-duplication is performed: entries with equal ids are all kept.
 *
 * @param stores     the stores to merge; must not be {@code null} and must not contain
 *                   {@code null} elements
 * @param <Embedded> the type of the embedded content
 * @return a new store holding the entries of all given stores
 */
public static <Embedded> InMemoryEmbeddingStore<Embedded> merge(Collection<InMemoryEmbeddingStore<Embedded>> stores) {
    ensureNotNull(stores, "stores");
    List<Entry<Embedded>> entries = new ArrayList<>();
    for (InMemoryEmbeddingStore<Embedded> store : stores) {
        // Reject a null element the same way the null collection is rejected above,
        // instead of failing on the field access below with a message-less NullPointerException.
        ensureNotNull(store, "store");
        entries.addAll(store.entries);
    }
    return new InMemoryEmbeddingStore<>(entries);
}

/**
 * Convenience overload that merges exactly two {@code InMemoryEmbeddingStore}s into a single
 * {@code InMemoryEmbeddingStore} by delegating to {@link #merge(Collection)}.
 * The entries of {@code first} are placed before the entries of {@code second}.
 *
 * @param first      the store whose entries come first
 * @param second     the store whose entries come second
 * @param <Embedded> the type of the embedded content
 * @return a new store holding the entries of both given stores
 */
public static <Embedded> InMemoryEmbeddingStore<Embedded> merge(InMemoryEmbeddingStore<Embedded> first,
                                                                InMemoryEmbeddingStore<Embedded> second) {
    List<InMemoryEmbeddingStore<Embedded>> pair = asList(first, second);
    return merge(pair);
}

private static class Entry<Embedded> {

String id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,36 @@ void test_backwards_compatibility_with_0_27_1() {
assertThat(matches.get(1).embedded()).isEqualTo(expectedSegment2);
}

@Test
void should_merge_multiple_stores() {

    // given: two separate stores, each holding a single segment under its own id
    TextSegment firstSegment = TextSegment.from("first", Metadata.from("first-key", "first-value"));
    Embedding firstEmbedding = embeddingModel.embed(firstSegment).content();
    InMemoryEmbeddingStore<TextSegment> firstStore = new InMemoryEmbeddingStore<>();
    firstStore.add("1", firstEmbedding, firstSegment);

    TextSegment secondSegment = TextSegment.from("second", Metadata.from("second-key", "second-value"));
    Embedding secondEmbedding = embeddingModel.embed(secondSegment).content();
    InMemoryEmbeddingStore<TextSegment> secondStore = new InMemoryEmbeddingStore<>();
    secondStore.add("2", secondEmbedding, secondSegment);

    // when
    InMemoryEmbeddingStore<TextSegment> mergedStore = InMemoryEmbeddingStore.merge(firstStore, secondStore);

    // then: querying with the first embedding returns both entries, the first one on top
    List<EmbeddingMatch<TextSegment>> found = mergedStore.findRelevant(firstEmbedding, 100);
    assertThat(found).hasSize(2);

    EmbeddingMatch<TextSegment> topMatch = found.get(0);
    assertThat(topMatch.embeddingId()).isEqualTo("1");
    assertThat(topMatch.embedding()).isEqualTo(firstEmbedding);
    assertThat(topMatch.embedded()).isEqualTo(firstSegment);

    EmbeddingMatch<TextSegment> nextMatch = found.get(1);
    assertThat(nextMatch.embeddingId()).isEqualTo("2");
    assertThat(nextMatch.embedding()).isEqualTo(secondEmbedding);
    assertThat(nextMatch.embedded()).isEqualTo(secondSegment);
}

private InMemoryEmbeddingStore<TextSegment> createEmbeddingStore() {

InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
Expand Down

0 comments on commit f601aad

Please sign in to comment.