From babd64a940fe5fadd1e6b6985b631dcf52726e12 Mon Sep 17 00:00:00 2001 From: ZYinNJU <1754350460@qq.com> Date: Mon, 29 Jan 2024 15:04:04 +0800 Subject: [PATCH] Integration with vearch (#525) Supports #425. Because of a problem in my local environment (the `vearch` Docker container fails to start on Apple M1), I ran the integration tests against a remote `vearch` instance (a `vearch` container started with Docker on a remote host), and they work fine. (However, I have not verified starting it up with `Testcontainers`.) Two more things need discussion and your opinion: 1. The `findRelevant` method translates between `RelevantScore` and `CosineSimilarity`, and I don't know whether that is correct: `vearch` does not support cosine similarity, so I use inner product instead (which is the same as cosine similarity when the vectors are normalized). Should we normalize a vector before adding it to the embedding store? 2. There are many constraints on creating a `vearch` space (each retrieval type has different parameters). Should we check them, or just let users check them themselves? (see [Create Space](https://vearch.readthedocs.io/en/latest/use_op/op_space.html#create-space)). Currently I implement this with several nested static classes (see `RetrievalParam` and `RetrievalType`); `SpaceEngine` performs some constraint checks. 
--- langchain4j-bom/pom.xml | 18 ++ .../langchain4j/data/embedding/Embedding.java | 15 + .../data/embedding/EmbeddingTest.java | 9 + langchain4j-vearch/pom.xml | 101 +++++++ .../store/embedding/vearch/BulkRequest.java | 16 ++ .../store/embedding/vearch/BulkResponse.java | 17 ++ .../vearch/CreateDatabaseRequest.java | 13 + .../vearch/CreateDatabaseResponse.java | 14 + .../embedding/vearch/CreateSpaceRequest.java | 24 ++ .../embedding/vearch/CreateSpaceResponse.java | 14 + .../vearch/ListDatabaseResponse.java | 14 + .../embedding/vearch/ListSpaceResponse.java | 14 + .../store/embedding/vearch/MetricType.java | 17 ++ .../store/embedding/vearch/ModelParam.java | 17 ++ .../embedding/vearch/ResponseWrapper.java | 15 + .../embedding/vearch/RetrievalParam.java | 116 ++++++++ .../store/embedding/vearch/RetrievalType.java | 20 ++ .../store/embedding/vearch/SearchRequest.java | 35 +++ .../embedding/vearch/SearchResponse.java | 48 ++++ .../store/embedding/vearch/SpaceEngine.java | 72 +++++ .../embedding/vearch/SpacePropertyParam.java | 124 ++++++++ .../embedding/vearch/SpacePropertyType.java | 26 ++ .../embedding/vearch/SpaceStoreParam.java | 25 ++ .../embedding/vearch/SpaceStoreType.java | 14 + .../store/embedding/vearch/VearchApi.java | 44 +++ .../store/embedding/vearch/VearchClient.java | 202 +++++++++++++ .../store/embedding/vearch/VearchConfig.java | 64 +++++ .../vearch/VearchEmbeddingStore.java | 270 ++++++++++++++++++ .../vearch/DeleteSpaceLastOrderer.java | 32 +++ .../vearch/VearchEmbeddingStoreIT.java | 146 ++++++++++ .../src/test/resources/config.toml | 86 ++++++ pom.xml | 1 + 32 files changed, 1643 insertions(+) create mode 100644 langchain4j-vearch/pom.xml create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkRequest.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkResponse.java create mode 100644 
langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseRequest.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseResponse.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceRequest.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceResponse.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListDatabaseResponse.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListSpaceResponse.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/MetricType.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ModelParam.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ResponseWrapper.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalParam.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalType.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchRequest.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchResponse.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceEngine.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyParam.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyType.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreParam.java create mode 100644 
langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreType.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchApi.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchClient.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchConfig.java create mode 100644 langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStore.java create mode 100644 langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/DeleteSpaceLastOrderer.java create mode 100644 langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java create mode 100644 langchain4j-vearch/src/test/resources/config.toml diff --git a/langchain4j-bom/pom.xml b/langchain4j-bom/pom.xml index ee0767287c7..616597ef599 100644 --- a/langchain4j-bom/pom.xml +++ b/langchain4j-bom/pom.xml @@ -185,6 +185,12 @@ ${project.version} + + dev.langchain4j + langchain4j-vearch + ${project.version} + + @@ -207,6 +213,18 @@ ${project.version} + + dev.langchain4j + langchain4j-document-loader-github + ${project.version} + + + + dev.langchain4j + langchain4j-document-loader-azure-storage-blob + ${project.version} + + diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/embedding/Embedding.java b/langchain4j-core/src/main/java/dev/langchain4j/data/embedding/Embedding.java index e81725cefa3..3cc920dcf12 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/embedding/Embedding.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/embedding/Embedding.java @@ -45,6 +45,21 @@ public List vectorAsList() { return list; } + /** + * Normalize vector + */ + public void normalize() { + double norm = 0.0; + for (float f : vector) { + norm += f * f; + } + norm = Math.sqrt(norm); + + for (int i = 0; i < vector.length; i++) { + vector[i] /= norm; + } + } + /** * 
Returns the dimension of the vector. * @return the dimension of the vector. diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/embedding/EmbeddingTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/embedding/EmbeddingTest.java index 104579ed91c..f3c446f8fc8 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/embedding/EmbeddingTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/embedding/EmbeddingTest.java @@ -48,4 +48,13 @@ public void test_from() { .isEqualTo(new Embedding(new float[]{1.0f, 2.0f, 3.0f})); } + @Test + void test_normalize() { + Embedding embedding = new Embedding(new float[]{6f, 8f}); + embedding.normalize(); + + Embedding expect = new Embedding(new float[]{0.6f, 0.8f}); + assertThat(embedding).isEqualTo(expect); + } + } \ No newline at end of file diff --git a/langchain4j-vearch/pom.xml b/langchain4j-vearch/pom.xml new file mode 100644 index 00000000000..c886c65754a --- /dev/null +++ b/langchain4j-vearch/pom.xml @@ -0,0 +1,101 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.26.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-vearch + jar + + LangChain4j integration with Vearch + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-gson + + + + com.squareup.okhttp3 + okhttp + + + + org.projectlombok + lombok + provided + + + + org.slf4j + slf4j-api + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.assertj + assertj-core + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + \ No newline at end of file diff --git 
a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkRequest.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkRequest.java new file mode 100644 index 00000000000..4e994292b0e --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkRequest.java @@ -0,0 +1,16 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; +import java.util.Map; + +@Getter +@Setter +@Builder +class BulkRequest { + + private List> documents; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkResponse.java new file mode 100644 index 00000000000..bdb4290f736 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/BulkResponse.java @@ -0,0 +1,17 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class BulkResponse { + + private Integer status; + private String error; + @SerializedName("_id") + private String id; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseRequest.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseRequest.java new file mode 100644 index 00000000000..5edd1602834 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseRequest.java @@ -0,0 +1,13 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class CreateDatabaseRequest { + + private String name; +} diff --git 
a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseResponse.java new file mode 100644 index 00000000000..34bb58f0a5d --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateDatabaseResponse.java @@ -0,0 +1,14 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class CreateDatabaseResponse { + + private Long id; + private String name; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceRequest.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceRequest.java new file mode 100644 index 00000000000..187fc80171a --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceRequest.java @@ -0,0 +1,24 @@ +package dev.langchain4j.store.embedding.vearch; + +import dev.langchain4j.store.embedding.vearch.ModelParam; +import dev.langchain4j.store.embedding.vearch.SpaceEngine; +import dev.langchain4j.store.embedding.vearch.SpacePropertyParam; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; +import java.util.Map; + +@Getter +@Setter +@Builder +class CreateSpaceRequest { + + private String name; + private Integer partitionNum; + private Integer replicaNum; + private SpaceEngine engine; + private Map properties; + private List models; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceResponse.java new file mode 100644 index 00000000000..74f2d958b75 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/CreateSpaceResponse.java @@ -0,0 +1,14 @@ 
+package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class CreateSpaceResponse { + + private Integer id; + private String name; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListDatabaseResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListDatabaseResponse.java new file mode 100644 index 00000000000..68747fb1468 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListDatabaseResponse.java @@ -0,0 +1,14 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class ListDatabaseResponse { + + private Integer id; + private String name; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListSpaceResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListSpaceResponse.java new file mode 100644 index 00000000000..ea9479a43c4 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ListSpaceResponse.java @@ -0,0 +1,14 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +public class ListSpaceResponse { + + private Integer id; + private String name; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/MetricType.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/MetricType.java new file mode 100644 index 00000000000..a18a2a8afea --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/MetricType.java @@ -0,0 +1,17 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; + +/** + * if metric type is not set when 
searching, it will use the parameter specified when building the space + * + *

LangChain4j currently only supports {@link MetricType#INNER_PRODUCT}

+ */ +public enum MetricType { + + /** + * Inner Product + */ + @SerializedName("InnerProduct") + INNER_PRODUCT +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ModelParam.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ModelParam.java new file mode 100644 index 00000000000..22054ade299 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ModelParam.java @@ -0,0 +1,17 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; + +@Getter +@Setter +@Builder +public class ModelParam { + + private String modelId; + private List fields; + private String out; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ResponseWrapper.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ResponseWrapper.java new file mode 100644 index 00000000000..a74defe43e7 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/ResponseWrapper.java @@ -0,0 +1,15 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +class ResponseWrapper { + + private Integer code; + private String msg; + private T data; +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalParam.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalParam.java new file mode 100644 index 00000000000..b35db1d0c0b --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalParam.java @@ -0,0 +1,116 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * As a constraint of all engine type only + * 
+ * @see RetrievalType + */ +public abstract class RetrievalParam { + + @Getter + @Setter + @Builder + public static class IVFPQParam extends RetrievalParam { + + @Builder.Default + private MetricType metricType = MetricType.INNER_PRODUCT; + /** + * number of buckets for indexing + * + *

default 2048

+ */ + private Integer ncentroids; + /** + * the number of sub vector + * + *

default 64, must be a multiple of 4

+ */ + private Integer nsubvector; + } + + @Getter + @Setter + @Builder + public static class HNSWParam extends RetrievalParam { + + @Builder.Default + private MetricType metricType = MetricType.INNER_PRODUCT; + /** + * neighbors number of each node + * + *

default 32

+ */ + private Integer nlinks; + /** + * expansion factor at construction time + * + *

default 40

+ *

The higher the value, the better the construction effect, and the longer it takes

+ */ + @SerializedName("efConstruction") + private Integer efConstruction; + } + + @Getter + @Setter + @Builder + public static class GPUParam extends RetrievalParam { + + @Builder.Default + private MetricType metricType = MetricType.INNER_PRODUCT; + /** + * number of buckets for indexing + * + *

default 2048

+ */ + private Integer ncentroids; + /** + * the number of sub vector + * + *

default 64

+ */ + private Integer nsubvector; + } + + @Getter + @Setter + @Builder + public static class IVFFLATParam extends RetrievalParam { + + @Builder.Default + private MetricType metricType = MetricType.INNER_PRODUCT; + /** + * number of buckets for indexing + * + *

default 2048

+ */ + private Integer ncentroids; + } + + @Getter + @Setter + @Builder + public static class BINARYIVFParam extends RetrievalParam { + + /** + * coarse cluster center number + * + *

default 256

+ */ + private Integer ncentroids; + } + + @Getter + @Setter + @Builder + public static class FLAT extends RetrievalParam { + + @Builder.Default + private MetricType metricType = MetricType.INNER_PRODUCT; + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalType.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalType.java new file mode 100644 index 00000000000..63538bb16a4 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/RetrievalType.java @@ -0,0 +1,20 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Getter; + +public enum RetrievalType { + + IVFPQ(RetrievalParam.IVFPQParam.class), + HNSW(RetrievalParam.HNSWParam.class), + GPU(RetrievalParam.GPUParam.class), + IVFFLAT(RetrievalParam.IVFFLATParam.class), + BINARYIVF(RetrievalParam.BINARYIVFParam.class), + FLAT(RetrievalParam.FLAT.class); + + @Getter + private Class paramClass; + + RetrievalType(Class paramClass) { + this.paramClass = paramClass; + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchRequest.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchRequest.java new file mode 100644 index 00000000000..fa8abc93638 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchRequest.java @@ -0,0 +1,35 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; + +@Getter +@Setter +@Builder +class SearchRequest { + + private QueryParam query; + private Integer size; + private List fields; + + @Getter + @Setter + @Builder + public static class QueryParam { + + private List sum; + } + + @Getter + @Setter + @Builder + public static class VectorParam { + + private String field; + private List feature; + private Double minScore; + } +} diff --git 
a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchResponse.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchResponse.java new file mode 100644 index 00000000000..00dbf7006a6 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SearchResponse.java @@ -0,0 +1,48 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; +import java.util.Map; + +@Getter +@Setter +@Builder +class SearchResponse { + + private Integer took; + @SerializedName("timed_out") + private Boolean timeout; + /** + * not support shards yet + */ + @SerializedName("_shards") + private Object shards; + private Hit hits; + + @Getter + @Setter + @Builder + public static class Hit { + + private Integer total; + private Double maxScore; + private List hits; + } + + @Getter + @Setter + @Builder + public static class SearchedDocument { + + @SerializedName("_id") + private String id; + @SerializedName("_score") + private Double score; + @SerializedName("_source") + private Map source; + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceEngine.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceEngine.java new file mode 100644 index 00000000000..1ede675683d --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceEngine.java @@ -0,0 +1,72 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class SpaceEngine { + + private String name; + private Long indexSize; + private RetrievalType retrievalType; + private RetrievalParam retrievalParam; + + public SpaceEngine() { + + } + + public SpaceEngine(String name, Long indexSize, RetrievalType retrievalType, RetrievalParam 
retrievalParam) { + setName(name); + setIndexSize(indexSize); + setRetrievalType(retrievalType); + setRetrievalParam(retrievalParam); + } + + public void setRetrievalParam(RetrievalParam retrievalParam) { + // do some constraint check + Class clazz = retrievalType.getParamClass(); + if (!clazz.isInstance(retrievalParam)) { + throw new UnsupportedOperationException( + String.format("can't assign unknown param of engine %s, please use class %s to assign engine param", + retrievalType.name(), clazz.getSimpleName())); + } + this.retrievalParam = retrievalParam; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + private Long indexSize; + private RetrievalType retrievalType; + private RetrievalParam retrievalParam; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder indexSize(Long indexSize) { + this.indexSize = indexSize; + return this; + } + + public Builder retrievalType(RetrievalType retrievalType) { + this.retrievalType = retrievalType; + return this; + } + + public Builder retrievalParam(RetrievalParam retrievalParam) { + this.retrievalParam = retrievalParam; + return this; + } + + public SpaceEngine build() { + return new SpaceEngine(name, indexSize, retrievalType, retrievalParam); + } + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyParam.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyParam.java new file mode 100644 index 00000000000..ad901981d22 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyParam.java @@ -0,0 +1,124 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * As a constraint type of all Space property only + * + * @see CreateSpaceRequest + */ +public abstract class SpacePropertyParam { + + 
protected SpacePropertyType type; + + SpacePropertyParam(SpacePropertyType type) { + this.type = type; + } + + @Getter + @Setter + public static class StringParam extends SpacePropertyParam { + + /** + * whether to create an index + */ + private Boolean index; + /** + * whether to allow multipart value + */ + private Boolean array; + + public StringParam() { + super(SpacePropertyType.STRING); + } + + @Builder + public StringParam(Boolean index, Boolean array) { + this(); + this.index = index; + this.array = array; + } + } + + @Getter + @Setter + public static class IntegerParam extends SpacePropertyParam { + + /** + * whether to create an index + * + *

set to true to support the use of numeric range filtering queries (not supported in langchain4j now)

+ */ + private Boolean index; + + public IntegerParam() { + super(SpacePropertyType.INTEGER); + } + + @Builder + public IntegerParam(Boolean index) { + this(); + this.index = index; + } + } + + @Getter + @Setter + public static class FloatParam extends SpacePropertyParam { + + /** + * whether to create an index + * + *

set to true to support the use of numeric range filtering queries (not supported in langchain4j now)

+ */ + private Boolean index; + + public FloatParam() { + super(SpacePropertyType.FLOAT); + } + + @Builder + public FloatParam(Boolean index) { + this(); + this.index = index; + } + } + + @Getter + @Setter + public static class VectorParam extends SpacePropertyParam { + + private Boolean index; + private Integer dimension; + /** + * "RocksDB" or "MemoryOnly". For HNSW and IVFFLAT and FLAT, it can only be run in MemoryOnly mode. + * + * @see SpaceStoreType + */ + private SpaceStoreType storeType; + private SpaceStoreParam storeParam; + private String modelId; + /** + * default not normalized. if you set "normalization", "normal" it will normalized + */ + private String format; + + public VectorParam() { + super(SpacePropertyType.VECTOR); + } + + @Builder + public VectorParam(Boolean index, Integer dimension, SpaceStoreType storeType, + SpaceStoreParam storeParam, String modelId, String format) { + this(); + this.index = index; + this.dimension = dimension; + this.storeType = storeType; + this.storeParam = storeParam; + this.modelId = modelId; + this.format = format; + } + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyType.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyType.java new file mode 100644 index 00000000000..de5f764708c --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpacePropertyType.java @@ -0,0 +1,26 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; +import lombok.Getter; + +public enum SpacePropertyType { + + /** + * keyword is equivalent to string + */ + @SerializedName("string") + STRING(SpacePropertyParam.StringParam.class), + @SerializedName("integer") + INTEGER(SpacePropertyParam.IntegerParam.class), + @SerializedName("float") + FLOAT(SpacePropertyParam.FloatParam.class), + @SerializedName("vector") + VECTOR(SpacePropertyParam.VectorParam.class); + + 
@Getter + private final Class paramClass; + + SpacePropertyType(Class paramClass) { + this.paramClass = paramClass; + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreParam.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreParam.java new file mode 100644 index 00000000000..48b4075bd83 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreParam.java @@ -0,0 +1,25 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@Builder +public class SpaceStoreParam { + + /** + * It means you will use so much memory, the excess will be kept to disk. For MemoryOnly, this parameter is invalid. + */ + private Integer cacheSize; + private CompressRate compress; + + @Getter + @Setter + @Builder + public static class CompressRate { + + private Integer rate; + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreType.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreType.java new file mode 100644 index 00000000000..0b09e6927e1 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/SpaceStoreType.java @@ -0,0 +1,14 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.annotations.SerializedName; +import lombok.Getter; + +public enum SpaceStoreType { + + @SerializedName("MemoryOnly") + MEMORY_ONLY, + @SerializedName("Mmap") + M_MAP, + @SerializedName("RocksDB") + ROCKS_DB +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchApi.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchApi.java new file mode 100644 index 00000000000..ea522fc963b --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchApi.java 
@@ -0,0 +1,44 @@ +package dev.langchain4j.store.embedding.vearch; + +import okhttp3.RequestBody; +import retrofit2.Call; +import retrofit2.http.*; + +import java.util.List; + +public interface VearchApi { + + int OK = 200; + + /* Database Operation */ + + @GET("/list/db") + Call>> listDatabase(); + + @PUT("/db/_create") + Call> createDatabase(@Body CreateDatabaseRequest request); + + @GET("/list/space") + Call>> listSpaceOfDatabase(@Query("db") String dbName); + + /* Space (like a table in relational database) Operation */ + + @PUT("/space/{db}/_create") + Call> createSpace(@Path("db") String dbName, + @Body CreateSpaceRequest request); + + /* Document Operation */ + + @POST("/{db}/{space}/_bulk") + Call> bulk(@Path("db") String db, + @Path("space") String space, + @Body RequestBody requestBody); + + @POST("/{db}/{space}/_search") + Call search(@Path("db") String db, + @Path("space") String space, + @Body SearchRequest request); + + @DELETE("/space/{db}/{space}") + Call deleteSpace(@Path("db") String dbName, @Path("space") String spaceName); +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchClient.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchClient.java new file mode 100644 index 00000000000..57a114cdd61 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchClient.java @@ -0,0 +1,202 @@ +package dev.langchain4j.store.embedding.vearch; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import lombok.Builder; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.RequestBody; +import retrofit2.Response; +import retrofit2.Retrofit; +import retrofit2.converter.gson.GsonConverterFactory; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; +import 
static dev.langchain4j.store.embedding.vearch.VearchApi.OK; +/** Synchronous client for VearchApi: executes each call and reports failed HTTP calls, result codes other than OK and I/O errors as RuntimeException. */ +class VearchClient { + /** Gson instance used by the Retrofit converter; field names follow the LOWER_CASE_WITH_UNDERSCORES policy. */ + private static final Gson GSON = new GsonBuilder() + .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) + .create(); + + private final VearchApi vearchApi; + /** Creates the Retrofit-backed API for baseUrl; the given timeout is applied as call, connect, read and write timeout. */ + @Builder + public VearchClient(String baseUrl, Duration timeout) { + OkHttpClient okHttpClient = new OkHttpClient.Builder() + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout) + .build(); + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(baseUrl) + .client(okHttpClient) + .addConverterFactory(GsonConverterFactory.create(GSON)) + .build(); + + vearchApi = retrofit.create(VearchApi.class); + } + /** Returns the data of the list-database response; throws if the wrapper code is not OK or the HTTP call did not succeed. */ + public List listDatabase() { + try { + Response>> response = vearchApi.listDatabase().execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper> wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Sends the create-database request and returns the data of the response; error handling is the same as in listDatabase. */ + public CreateDatabaseResponse createDatabase(CreateDatabaseRequest request) { + try { + Response> response = vearchApi.createDatabase(request).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Returns the data of the list-space response for the given database; error handling is the same as in listDatabase. */ + public List listSpace(String dbName) { + try { + Response>> response = vearchApi.listSpaceOfDatabase(dbName).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper> wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { +
throw new RuntimeException(e); + } + } + /** Creates a space in the given database and returns the data of the response; a failed HTTP call or a wrapper code other than OK is reported as RuntimeException. */ + public CreateSpaceResponse createSpace(String dbName, CreateSpaceRequest request) { + try { + Response> response = vearchApi.createSpace(dbName, request).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Sends all documents in one newline-delimited body: per document an index action line carrying its "_id", then a line with the remaining fields as JSON. Throws RuntimeException if the HTTP call fails or any returned item has a status other than OK. */ + public void bulk(String dbName, String spaceName, BulkRequest request) { + try { + StringBuilder bodyString = new StringBuilder(); + for (Map document : request.getDocuments()) { + Map fieldsExceptId = new HashMap<>(); + for (Map.Entry entry : document.entrySet()) { + String fieldName = entry.getKey(); + Object value = entry.getValue(); + /* the id is written through Gson so that quotes or backslashes in it are escaped and the action line stays valid JSON */ + if ("_id".equals(fieldName)) { + bodyString.append("{\"index\": {\"_id\": ").append(GSON.toJson(String.valueOf(value))).append("}}\n"); + } else { + fieldsExceptId.put(fieldName, value); + } + } + bodyString.append(GSON.toJson(fieldsExceptId)).append("\n"); + } + RequestBody body = RequestBody.create(bodyString.toString(), MediaType.parse("application/json; charset=utf-8")); + Response> response = vearchApi.bulk(dbName, spaceName, body).execute(); + + if (response.isSuccessful() && response.body() != null) { + List bulkResponses = response.body(); + bulkResponses.forEach(bulkResponse -> { + if (bulkResponse.getStatus() != OK) { + throw toException(bulkResponse.getStatus(), bulkResponse.getError()); + } + }); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Executes a search in the given space and returns the response; a failed HTTP call or a response flagged as timed out is reported as RuntimeException. */ + public SearchResponse search(String dbName, String spaceName, SearchRequest request) { + try { + Response response = vearchApi.search(dbName, spaceName, request).execute(); + + if (response.isSuccessful() && response.body() != null) { + SearchResponse searchResponse = response.body(); + if
(Boolean.TRUE.equals(searchResponse.getTimeout())) { + throw new RuntimeException("Search Timeout"); + } + return searchResponse; + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Deletes the given space; an unsuccessful HTTP response or an I/O error is reported as RuntimeException. */ + public void deleteSpace(String databaseName, String spaceName) { + try { + Response response = vearchApi.deleteSpace(databaseName, spaceName).execute(); + + if (!response.isSuccessful()) { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Builds the exception for a failed call from the HTTP status code and the error body. Callers also get here for a successful response whose body() is null; errorBody() is null in that case, so an empty body text is used instead of dereferencing it. */ + private RuntimeException toException(Response response) throws IOException { + int code = response.code(); + String body = response.errorBody() == null ? "" : response.errorBody().string(); + + String errorMessage = String.format("status code: %s; body: %s", code, body); + return new RuntimeException(errorMessage); + } + /** Builds the exception from the code and message of a response wrapper. */ + private RuntimeException toException(ResponseWrapper responseWrapper) { + return toException(responseWrapper.getCode(), responseWrapper.getMsg()); + } + /** Builds the exception from a result code and its message. */ + private RuntimeException toException(int code, String msg) { + String errorMessage = String.format("code: %s; message: %s", code, msg); + + return new RuntimeException(errorMessage); + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchConfig.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchConfig.java new file mode 100644 index 00000000000..609b717bf71 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchConfig.java @@ -0,0 +1,64 @@ +package dev.langchain4j.store.embedding.vearch; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.singletonList; +/** Configuration holder: database and space names, space engine, space properties, field names and model parameters. */ +@Getter +@Setter +@Builder +public class VearchConfig { + + private String databaseName; + private String spaceName; + private SpaceEngine spaceEngine; + /** + * This attribute's key set
should contain + * {@link VearchConfig#embeddingFieldName}, {@link VearchConfig#textFieldName} and {@link VearchConfig#metadataFieldNames} + */ + private Map properties; + @Builder.Default + private String embeddingFieldName = "embedding"; + @Builder.Default + private String textFieldName = "text"; + private List modelParams; + /** + * This attribute should be a subset of the key set of {@link VearchConfig#properties} + */ + private List metadataFieldNames; + /** Returns a configuration for database "embedding_db" and space "embedding_space": an indexed 384-dimension MemoryOnly vector field "embedding", a string field "text", and the "gamma" engine with FLAT retrieval. metadataFieldNames is not set. */ + public static VearchConfig getDefaultConfig() { + // init properties + Map properties = new HashMap<>(4); + properties.put("embedding", SpacePropertyParam.VectorParam.builder() + .index(true) + .storeType(SpaceStoreType.MEMORY_ONLY) + .dimension(384) + .build()); + properties.put("text", SpacePropertyParam.StringParam.builder().build()); + + return VearchConfig.builder() + .spaceEngine(SpaceEngine.builder() + .name("gamma") + .indexSize(1L) + .retrievalType(RetrievalType.FLAT) + .retrievalParam(RetrievalParam.FLAT.builder() + .build()) + .build()) + .properties(properties) + .databaseName("embedding_db") + .spaceName("embedding_space") + .modelParams(singletonList(ModelParam.builder() + .modelId("vgg16") + .fields(singletonList("string")) + .out("feature") + .build())) + .build(); + } +} diff --git a/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStore.java b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStore.java new file mode 100644 index 00000000000..8c3bd8c16a0 --- /dev/null +++ b/langchain4j-vearch/src/main/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStore.java @@ -0,0 +1,270 @@ +package dev.langchain4j.store.embedding.vearch; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import
dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; + +import java.time.Duration; +import java.util.*; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.*; +import static java.time.Duration.ofSeconds; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; +/** EmbeddingStore implementation that stores and searches embeddings in a vearch space through VearchClient. */ +public class VearchEmbeddingStore implements EmbeddingStore { + + private final VearchConfig vearchConfig; + private final VearchClient vearchClient; + /** + * Whether embeddings are normalized when they are added to the embedding store + */ + private final boolean normalizeEmbeddings; + /** Creates the store. baseUrl is required; a null vearchConfig falls back to VearchConfig.getDefaultConfig(), a null timeout to 60 seconds and a null normalizeEmbeddings to false. The configured database and space are created if they are not listed yet. */ + public VearchEmbeddingStore(String baseUrl, + Duration timeout, + VearchConfig vearchConfig, + Boolean normalizeEmbeddings) { + // Step 0: initialize attributes + baseUrl = ensureNotNull(baseUrl, "baseUrl"); + this.vearchConfig = getOrDefault(vearchConfig, VearchConfig.getDefaultConfig()); + this.normalizeEmbeddings = getOrDefault(normalizeEmbeddings, false); + + vearchClient = VearchClient.builder() + .baseUrl(baseUrl) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .build(); + + // Step 1: check whether the database exists; if not, create it + if (!isDatabaseExist(this.vearchConfig.getDatabaseName())) { + createDatabase(this.vearchConfig.getDatabaseName()); + } + + // Step 2: check whether the space exists; if not, create it + if (!isSpaceExist(this.vearchConfig.getDatabaseName(), this.vearchConfig.getSpaceName())) { + createSpace(this.vearchConfig.getDatabaseName(), this.vearchConfig.getSpaceName()); + } + } + + public static Builder builder() { + return new Builder(); + } + /** Fluent builder; build() passes the collected values to the VearchEmbeddingStore constructor. */ + public static class Builder { + + private VearchConfig vearchConfig; + private String baseUrl; + private Duration timeout; + private Boolean normalizeEmbeddings; + + public Builder vearchConfig(VearchConfig vearchConfig) { + this.vearchConfig = vearchConfig; + return this; + } + + public Builder baseUrl(String baseUrl) { +
this.baseUrl = baseUrl; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + /** + * Set whether to normalize embeddings when they are added to the embedding store + * + * @param normalizeEmbeddings whether to normalize embeddings when they are added to the embedding store + * @return builder + */ + public Builder normalizeEmbeddings(Boolean normalizeEmbeddings) { + this.normalizeEmbeddings = normalizeEmbeddings; + return this; + } + + public VearchEmbeddingStore build() { + return new VearchEmbeddingStore(baseUrl, timeout, vearchConfig, normalizeEmbeddings); + } + } + /** Adds the embedding under a newly generated random UUID and returns that id. */ + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + /** Adds the embedding under the given id, without a text segment. */ + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + /** Adds the embedding together with its text segment under a newly generated random UUID and returns that id. */ + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + /** Adds all embeddings, each under a newly generated random UUID, and returns the ids in the same order. */ + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + /** Adds all embeddings with their text segments, each under a newly generated random UUID, and returns the ids in the same order. */ + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + /** Searches the space with the reference embedding. minScore is converted with CosineSimilarity.fromRelevanceScore and sent as the min score of the vector query; the text field, the embedding field and, when configured, the metadata fields are requested. metadataFieldNames may be null (the default config does not set it), so it is only added when present. */ + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + double minSimilarity = CosineSimilarity.fromRelevanceScore(minScore); + List fields = new ArrayList<>(Arrays.asList(vearchConfig.getTextFieldName(), vearchConfig.getEmbeddingFieldName())); + if (!isNullOrEmpty(vearchConfig.getMetadataFieldNames())) { fields.addAll(vearchConfig.getMetadataFieldNames()); } + SearchRequest request = SearchRequest.builder() + .query(SearchRequest.QueryParam.builder() + .sum(singletonList(SearchRequest.VectorParam.builder() + .field(vearchConfig.getEmbeddingFieldName()) +
.feature(referenceEmbedding.vectorAsList()) + .minScore(minSimilarity) + .build())) + .build()) + .size(maxResults) + .fields(fields) + .build(); + + SearchResponse response = vearchClient.search(vearchConfig.getDatabaseName(), vearchConfig.getSpaceName(), request); + return toEmbeddingMatch(response.getHits()); + } + /** Deletes the configured space from the configured database. */ + public void deleteSpace() { + vearchClient.deleteSpace(vearchConfig.getDatabaseName(), vearchConfig.getSpaceName()); + } + /** Single-item variant of addAllInternal; a null text segment is passed on as a null list. */ + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); + } + /** Builds one document per id (the vector under "feature" in the embedding field, the text, the metadata fields) and sends them in one bulk request. ids and embeddings must be non-empty and of equal size; embedded may be null, otherwise it must match in size. When normalizeEmbeddings is set, normalize() is called on the passed Embedding objects themselves. */ + private void addAllInternal(List ids, List embeddings, List embedded) { + ids = ensureNotEmpty(ids, "ids"); + embeddings = ensureNotEmpty(embeddings, "embeddings"); + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); + + List> documents = new ArrayList<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + Map document = new HashMap<>(4); + document.put("_id", ids.get(i)); + Map> embeddingValue = new HashMap<>(1); + Embedding embedding = embeddings.get(i); + if (normalizeEmbeddings) { + embedding.normalize(); + } + embeddingValue.put("feature", embedding.vectorAsList()); + document.put(vearchConfig.getEmbeddingFieldName(), embeddingValue); + if (embedded != null) { + document.put(vearchConfig.getTextFieldName(), embedded.get(i).text()); + Map metadata = embedded.get(i).metadata().asMap(); + if (!isNullOrEmpty(vearchConfig.getMetadataFieldNames())) { /* metadataFieldNames is null in the default config */ + for (String metadataFieldName : vearchConfig.getMetadataFieldNames()) { metadata.putIfAbsent(metadataFieldName, ""); } + } + document.putAll(metadata); + } else { + // vearch does not allow null values + document.put(vearchConfig.getTextFieldName(), ""); + if (!isNullOrEmpty(vearchConfig.getMetadataFieldNames())) { + for (String metadataFieldName :
vearchConfig.getMetadataFieldNames()) { + document.put(metadataFieldName, ""); + } + } + } + documents.add(document); + } + BulkRequest request = BulkRequest.builder() + .documents(documents) + .build(); + vearchClient.bulk(vearchConfig.getDatabaseName(), vearchConfig.getSpaceName(), request); + } + /** Returns true if the listed databases contain one with the given name. */ + private boolean isDatabaseExist(String databaseName) { + List databases = vearchClient.listDatabase(); + return databases.stream().anyMatch(database -> databaseName.equals(database.getName())); + } + /** Creates a database with the given name. */ + private void createDatabase(String databaseName) { + vearchClient.createDatabase(CreateDatabaseRequest.builder() + .name(databaseName) + .build()); + } + /** Returns true if the spaces listed for the database contain one with the given name. */ + private boolean isSpaceExist(String databaseName, String spaceName) { + List spaces = vearchClient.listSpace(databaseName); + return spaces.stream().anyMatch(space -> spaceName.equals(space.getName())); + } + /** Creates the space with the engine, properties and model parameters of the config; replica and partition numbers are fixed to 1. */ + private void createSpace(String databaseName, String space) { + vearchClient.createSpace(databaseName, CreateSpaceRequest.builder() + .name(space) + .engine(vearchConfig.getSpaceEngine()) + .replicaNum(1) + .partitionNum(1) + .properties(vearchConfig.getProperties()) + .models(vearchConfig.getModelParams()) + .build()); + } + /** Converts the hits of a search into EmbeddingMatch objects; returns an empty list when there are no hits. The vector is read from "feature" of the embedding field, and a text segment is only built when the text field is not blank. */ + @SuppressWarnings("unchecked") + private List> toEmbeddingMatch(SearchResponse.Hit hit) { + List searchedDocuments = hit.getHits(); + if (isNullOrEmpty(searchedDocuments)) { + return new ArrayList<>(); + } + + return searchedDocuments.stream().map(searchedDocument -> { + Map source = searchedDocument.getSource(); + String id = searchedDocument.getId(); + List vector = (List) ((Map) source.get(vearchConfig.getEmbeddingFieldName())).get("feature"); + Embedding embedding = Embedding.from(vector.stream().map(Double::floatValue).collect(toList())); + + TextSegment textSegment = null; + String text = source.get(vearchConfig.getTextFieldName()) == null ?
null : String.valueOf(source.get(vearchConfig.getTextFieldName())); + if (!isNullOrBlank(text)) { + Map metadataMap = convertMetadataMap(source); + textSegment = TextSegment.from(text, Metadata.from(metadataMap)); + } + + return new EmbeddingMatch<>(RelevanceScore.fromCosineSimilarity(searchedDocument.getScore()), id, embedding, textSegment); + }).collect(toList()); + } + /** Turns the remaining source fields into metadata: every value is converted with String.valueOf and blank results are skipped. */ + private Map convertMetadataMap(Map source) { + // NOTE: the text and embedding entries are removed from the passed map itself; whether this in-place removal is risky is still open + source.remove(vearchConfig.getTextFieldName()); + source.remove(vearchConfig.getEmbeddingFieldName()); + if (source.isEmpty()) { + return new HashMap<>(); + } + Map metadataMap = new HashMap<>(source.size()); + source.forEach((key, value) -> { + if (!isNullOrBlank(String.valueOf(value))) { + metadataMap.put(key, String.valueOf(value)); + } + }); + + return metadataMap; + } +} diff --git a/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/DeleteSpaceLastOrderer.java b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/DeleteSpaceLastOrderer.java new file mode 100644 index 00000000000..109b1e2b4f6 --- /dev/null +++ b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/DeleteSpaceLastOrderer.java @@ -0,0 +1,32 @@ +package dev.langchain4j.store.embedding.vearch; + +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.MethodOrdererContext; + +import java.util.ArrayList; +import java.util.List; +/** MethodOrderer that moves the test method named "should_delete_space" behind the other test methods. */ +public class DeleteSpaceLastOrderer implements MethodOrderer { + @Override + public void orderMethods(MethodOrdererContext methodOrdererContext) { + // must be equal to the name of the VearchEmbeddingStoreIT#should_delete_space test + String deleteSpaceTestName = "should_delete_space"; + List methodNames = new ArrayList<>(); + methodOrdererContext.getMethodDescriptors().forEach(methodDescriptor -> methodNames.add(methodDescriptor.getMethod().getName())); + methodNames.sort((methodName1, methodName2) -> { + // the delete-space test compares greater than every other name; all other names compare as equal + if
(methodName1.equals(deleteSpaceTestName)) { + return 1; + } else if (methodName2.equals(deleteSpaceTestName)) { + return -1; + } else { + return 0; + } + }); + /* reorder the method descriptors by the position of their method name in the sorted name list */ methodOrdererContext.getMethodDescriptors().sort((md1, md2) -> { + int index1 = methodNames.indexOf(md1.getMethod().getName()); + int index2 = methodNames.indexOf(md2.getMethod().getName()); + return Integer.compare(index1, index2); + }); + } +} diff --git a/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java new file mode 100644 index 00000000000..602d4bb64e1 --- /dev/null +++ b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java @@ -0,0 +1,146 @@ +package dev.langchain4j.store.embedding.vearch; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.testcontainers.containers.BindMode; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.DockerImageName; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +/** Integration test of VearchEmbeddingStore against a vearch container; test methods are ordered by DeleteSpaceLastOrderer. */ +@TestMethodOrder(DeleteSpaceLastOrderer.class) +public class VearchEmbeddingStoreIT extends EmbeddingStoreIT { + /** Path of the config.toml test resource; it is bound read-only into the container as /vearch/config.toml. */ + static String configPath =
VearchEmbeddingStoreIT.class.getClassLoader().getResource("config.toml").getPath(); + static GenericContainer vearch = new GenericContainer<>(DockerImageName.parse("vearch/vearch:latest")) + .withCommand("all") + .withFileSystemBind(configPath, "/vearch/config.toml", BindMode.READ_ONLY) + .waitingFor(Wait.forLogMessage(".*INFO : server pid:1.*\\n", 1)); + + VearchEmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + /** + * Separate client, used to clear the embedding store between tests + */ + VearchClient vearchClient; + + String databaseName; + + String spaceName; + + VearchConfig vearchConfig; + /** Builds the test config (vector field "text_embedding", text field "text", metadata field "test-key"), then creates the store and a separate client against the container's mapped port 9001. */ + public VearchEmbeddingStoreIT() { + String embeddingFieldName = "text_embedding"; + String textFieldName = "text"; + String metadataFieldName = "test-key"; + + this.databaseName = "embedding_db"; + this.spaceName = "embedding_space"; + + // init properties + Map properties = new HashMap<>(4); + properties.put(embeddingFieldName, SpacePropertyParam.VectorParam.builder() + .index(true) + .storeType(SpaceStoreType.MEMORY_ONLY) + .dimension(384) + .build()); + properties.put(textFieldName, SpacePropertyParam.StringParam.builder().build()); + // metadata + properties.put(metadataFieldName, SpacePropertyParam.StringParam.builder().build()); + + // init vearch config + this.vearchConfig = VearchConfig.builder() + .spaceEngine(SpaceEngine.builder() + .name("gamma") + .indexSize(1L) + .retrievalType(RetrievalType.FLAT) + .retrievalParam(RetrievalParam.FLAT.builder() + .build()) + .build()) + .properties(properties) + .embeddingFieldName(embeddingFieldName) + .textFieldName(textFieldName) + .databaseName(this.databaseName) + .spaceName(this.spaceName) + .modelParams(singletonList(ModelParam.builder() + .modelId("vgg16") + .fields(singletonList("string")) + .out("feature") + .build())) + .metadataFieldNames(singletonList(metadataFieldName)) + .build(); + + // init embedding store and vearch client + String baseUrl = "http://" + vearch.getHost() +
":" + vearch.getMappedPort(9001); + embeddingStore = VearchEmbeddingStore.builder() + .vearchConfig(this.vearchConfig) + .baseUrl(baseUrl) + .build(); + + vearchClient = VearchClient.builder() + .baseUrl(baseUrl) + .timeout(Duration.ofSeconds(60)) + .build(); + } + /** Binds container ports 9001 and 8817 to the same host ports and starts the container. */ + @BeforeAll + static void beforeAll() { + vearch.setPortBindings(Arrays.asList("9001:9001", "8817:8817")); + vearch.start(); + } + + @AfterAll + static void afterAll() { + vearch.stop(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + /** Clears the store by deleting the space and creating it again with the same engine, properties and model parameters. */ + @Override + protected void clearStore() { + vearchClient.deleteSpace(databaseName, spaceName); + + vearchClient.createSpace(databaseName, CreateSpaceRequest.builder() + .name(spaceName) + .engine(vearchConfig.getSpaceEngine()) + .replicaNum(1) + .partitionNum(1) + .properties(vearchConfig.getProperties()) + .models(vearchConfig.getModelParams()) + .build()); + } + /** Deletes the space through the store and checks that it is no longer listed for the database. */ + @Test + void should_delete_space() { + embeddingStore.deleteSpace(); + List actual = vearchClient.listSpace(databaseName); + assertThat(actual.stream().map(ListSpaceResponse::getName)).doesNotContain(spaceName); + } + +} diff --git a/langchain4j-vearch/src/test/resources/config.toml b/langchain4j-vearch/src/test/resources/config.toml new file mode 100644 index 00000000000..5f83ac03cc0 --- /dev/null +++ b/langchain4j-vearch/src/test/resources/config.toml @@ -0,0 +1,86 @@ +[global] + # the name will validate join cluster by same name + name = "cbdb" + # specify which resources to use to create space + resource_name = "default" + # you data save to disk path ,If you are in a production environment, You'd better set absolute paths + data = ["datas/","datas1/"] + # log path , If you are in a production environment, You'd better set absolute paths + log = "logs/" + # default log type for any model + level = "debug" + # master <-> ps <-> router will use this key to send or receive
data + signkey = "secret" + # skip auth for master and router + skip_auth = true + # tell Vearch whether it should manage its own instance of etcd or not + self_manage_etcd = false + # automatically remove the failed node and recover when new nodes join + auto_recover_ps = false + # support etcd basic auth; depends on self_manage_etcd = true + support_etcd_auth = false + # ensure leader-follower raft data synchronization is consistent + raft_consistent = false + +# self_manage_etcd = true means you manage etcd yourself and need to provide additional configuration +[etcd] + # etcd server ip or domain + address = ["127.0.0.1"] + # advertise_client_urls AND listen_client_urls + etcd_client_port = 2379 + # provide username and password if you turn on auth + user_name = "root" + password = "" + +# if you are master, you'd better set all config for router and ps, so that router and ps can use the default config +[[masters]] + # machine name for cluster + name = "m1" + # ip or domain + address = "127.0.0.1" + # api port for http server + api_port = 8817 + # port for etcd server + etcd_port = 2378 + # listen_peer_urls List of comma separated URLs to listen on for peer traffic. + # advertise_peer_urls List of this member's peer URLs to advertise to the rest of the cluster. The URLs need to be a comma-separated list. + etcd_peer_port = 2390 + # List of this member's client URLs to advertise to the public. + # The URLs need to be a comma-separated list.
+ # advertise_client_urls AND listen_client_urls + etcd_client_port = 2370 + # init cluster state + cluster_state = "new" + pprof_port = 6062 + # monitor + monitor_port = 8818 + +[router] + # port for server + port = 9001 + # rpc_port = 9002 + pprof_port = 6061 + plugin_path = "plugin" + +[ps] + # port for server + rpc_port = 8081 + ps_heartbeat_timeout = 5 #seconds + # raft config begin + raft_heartbeat_port = 8898 + raft_replicate_port = 8899 + heartbeat-interval = 200 #ms + raft_retain_logs = 20000000 + raft_replica_concurrency = 1 + raft_snap_concurrency = 1 + raft_truncate_count = 500000 + # when behind the leader by this value, will stop the server for search + raft_diff_count = 10000 + # engine config + engine_dwpt_num = 8 + pprof_port = 6060 + # if set to true, this ps is only used in db meta config + private = false + # seconds + flush_time_interval = 600 + flush_count_threshold = 200000 diff --git a/pom.xml b/pom.xml index 7da45291b4d..d84ef33b941 100644 --- a/pom.xml +++ b/pom.xml @@ -46,6 +46,7 @@ langchain4j-vespa langchain4j-weaviate langchain4j-neo4j + langchain4j-vearch document-loaders/langchain4j-document-loader-amazon-s3