Skip to content

Commit

Permalink
[Feature] Pinecone: support storing metadata and embedding removal (l…
Browse files Browse the repository at this point in the history
…angchain4j#1400)

## Issue
Closes langchain4j#1169
Fixes langchain4j#1418

## Change
1. Refactor `PineconeEmbeddingStore`, update `pinecone-client` version
to latest.
2. Support storing metadata
3. Support embedding removal methods (not including `removeAll(Filter filter)`, because I
haven't found a way to convert `Filter` to
`com.google.protobuf.Struct` :(, but I will try my best to work on
it and will create a new PR if I complete it.)

## General checklist
- [ ] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)

## Checklist for changing existing embedding store integration
- [x] I have manually verified that the
`{NameOfIntegration}EmbeddingStore` works correctly with the data
persisted using the latest released version of LangChain4j
  • Loading branch information
Martin7-1 authored Jul 23, 2024
1 parent 44ac063 commit e848e6a
Show file tree
Hide file tree
Showing 10 changed files with 409 additions and 140 deletions.
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
package dev.langchain4j.store.embedding;

import static dev.langchain4j.data.document.Metadata.metadata;
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.Collection;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.Collection;
import java.util.List;

import static dev.langchain4j.data.document.Metadata.metadata;
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class EmbeddingStoreWithRemovalIT {

protected abstract EmbeddingStore<TextSegment> embeddingStore();
Expand All @@ -33,6 +34,8 @@ void should_remove_by_id() {
Embedding embedding2 = embeddingModel().embed("test2").content();
String id2 = embeddingStore().add(embedding2);

awaitUntilPersisted();

assertThat(getAllEmbeddings()).hasSize(2);

// when
Expand All @@ -49,8 +52,8 @@ void should_remove_by_id() {
@ValueSource(strings = " ")
void should_fail_to_remove_by_id(String id) {
assertThatThrownBy(() -> embeddingStore().remove(id))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("id cannot be null or blank");
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("id cannot be null or blank");
}

@Test
Expand All @@ -65,6 +68,8 @@ void should_remove_all_by_ids() {
Embedding embedding3 = embeddingModel().embed("test3").content();
String id3 = embeddingStore().add(embedding3);

awaitUntilPersisted();

assertThat(getAllEmbeddings()).hasSize(3);

// when
Expand All @@ -79,15 +84,15 @@ void should_remove_all_by_ids() {
// Verifies that removeAll(Collection<String>) rejects a null id collection with an
// IllegalArgumentException carrying the message "ids cannot be null or empty".
// NOTE(review): the assertion chain below appears twice because this view shows both the
// old and the re-indented new lines of the diff; only one copy exists in the real file.
@Test
void should_fail_to_remove_all_by_ids_null() {
// The cast selects the Collection<String> overload; a bare null would be ambiguous
// with the removeAll(Filter) overload exercised elsewhere in this class.
assertThatThrownBy(() -> embeddingStore().removeAll((Collection<String>) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
}

// Verifies that removeAll(Collection<String>) rejects an empty id collection with an
// IllegalArgumentException carrying the message "ids cannot be null or empty".
// NOTE(review): the assertion chain below appears twice because this view shows both the
// old and the re-indented new lines of the diff; only one copy exists in the real file.
@Test
void should_fail_to_remove_all_by_ids_empty() {
assertThatThrownBy(() -> embeddingStore().removeAll(emptyList()))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("ids cannot be null or empty");
}

@Test
Expand All @@ -104,6 +109,8 @@ void should_remove_all_by_filter() {
Embedding embedding3 = embeddingModel().embed("not matching").content();
String id3 = embeddingStore().add(embedding3);

awaitUntilPersisted();

assertThat(getAllEmbeddings()).hasSize(3);

// when
Expand All @@ -118,8 +125,8 @@ void should_remove_all_by_filter() {
// Verifies that removeAll(Filter) rejects a null filter with an
// IllegalArgumentException carrying the message "filter cannot be null".
// NOTE(review): the assertion chain below appears twice because this view shows both the
// old and the re-indented new lines of the diff; only one copy exists in the real file.
@Test
void should_fail_to_remove_all_by_filter_null() {
// The cast selects the Filter overload of removeAll for the null argument.
assertThatThrownBy(() -> embeddingStore().removeAll((Filter) null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("filter cannot be null");
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("filter cannot be null");
}

@Test
Expand All @@ -131,6 +138,8 @@ void should_remove_all() {
Embedding embedding2 = embeddingModel().embed("test2").content();
embeddingStore().add(embedding2);

awaitUntilPersisted();

assertThat(getAllEmbeddings()).hasSize(2);

// when
Expand All @@ -142,13 +151,17 @@ void should_remove_all() {

/**
 * Searches the store under test with the embedding of the text "test", asking for at most
 * 1000 results, and returns the matches. The tests in this class use the size of the
 * returned list to check how many embeddings are stored before and after a removal.
 *
 * NOTE(review): the builder chain below appears twice because this view shows both the old
 * and the re-indented new lines of the diff; only one copy exists in the real file.
 *
 * @return the matches of the search result
 */
protected List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
.builder()
.queryEmbedding(embeddingModel().embed("test").content())
.maxResults(1000)
.build();
.builder()
.queryEmbedding(embeddingModel().embed("test").content())
.maxResults(1000)
.build();

EmbeddingSearchResult<TextSegment> searchResult = embeddingStore().search(embeddingSearchRequest);

return searchResult.matches();
}

/**
 * Hook invoked by the removal tests after embeddings are added and before the store
 * contents are asserted. The default implementation does nothing.
 *
 * NOTE(review): presumably meant to be overridden by subclasses whose store persists
 * writes asynchronously, so they can wait for the added embeddings to become
 * searchable; confirm against the overriding subclasses.
 */
protected void awaitUntilPersisted() {

}
}
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
<aws.java.sdk.version>2.21.44</aws.java.sdk.version>
<github-api.version>1.318</github-api.version>
<milvus-sdk-java.version>2.3.6</milvus-sdk-java.version>
<netty.version>4.1.104.Final</netty.version>
<netty.version>4.1.111.Final</netty.version>
<awaitility.version>4.2.0</awaitility.version>
<jsonpath.version>2.9.0</jsonpath.version>
<infinispan.version>15.0.0.Final</infinispan.version>
Expand Down
17 changes: 3 additions & 14 deletions langchain4j-pinecone/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,17 @@
<dependency>
<groupId>io.pinecone</groupId>
<artifactId>pinecone-client</artifactId>
<version>0.6.0</version>
<version>1.2.2</version>
<exclusions>
<!-- CVE-2023-44487 -->
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
</exclusion>
<!-- CVE-2023-3635 -->
<exclusion>
<groupId>com.squareup.okio</groupId>
<artifactId>okio-jvm</artifactId>
</exclusion>
<!-- CVE-2020-29582 -->
<exclusion>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
<artifactId>netty-codec-http</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
<artifactId>netty-codec-http</artifactId>
<version>${netty.version}</version>
</dependency>

Expand Down
Loading

0 comments on commit e848e6a

Please sign in to comment.