From 9c62fe9d3bdafd377339ec0fb45027f1f4e1dbb7 Mon Sep 17 00:00:00 2001 From: deep-learning-dynamo Date: Tue, 26 Sep 2023 10:19:48 +0200 Subject: [PATCH] Removed dynamic loading from RedisEmbeddingStore --- langchain4j-redis/pom.xml | 28 +-- .../store/embedding/redis/MetricType.java | 6 +- ...toreImpl.java => RedisEmbeddingStore.java} | 162 +++++++++++++---- .../store/embedding/redis/RedisSchema.java | 12 +- .../redis/RedisEmbeddingStoreImplTest.java | 85 --------- .../redis/RedisEmbeddingStoreTest.java | 96 ++++++++++ .../embedding/redis/RedisEmbeddingStore.java | 169 ------------------ 7 files changed, 240 insertions(+), 318 deletions(-) rename langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/{RedisEmbeddingStoreImpl.java => RedisEmbeddingStore.java} (50%) delete mode 100644 langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImplTest.java create mode 100644 langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java delete mode 100644 langchain4j/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java diff --git a/langchain4j-redis/pom.xml b/langchain4j-redis/pom.xml index f0558fa3ab3..66c5fe2f141 100644 --- a/langchain4j-redis/pom.xml +++ b/langchain4j-redis/pom.xml @@ -14,29 +14,26 @@ jar LangChain4j integration with Redis - Uses jedis library which has a MIT license: - https://github.com/redis/jedis/blob/master/LICENSE - + dev.langchain4j langchain4j-core ${project.version} + redis.clients jedis - - com.google.code.gson - gson - + org.projectlombok lombok provided + org.slf4j slf4j-api @@ -47,16 +44,19 @@ junit-jupiter-engine test + org.mockito mockito-core test + org.mockito mockito-junit-jupiter test + org.tinylog tinylog-impl @@ -68,19 +68,7 @@ slf4j-tinylog test - - - - - org.honton.chas - license-maven-plugin - - - true - - - - + \ No newline at end of file diff --git a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/MetricType.java b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/MetricType.java index 7c42c0e69e4..9c228825e45 100644 --- a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/MetricType.java +++ b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/MetricType.java @@ -1,18 +1,20 @@ package dev.langchain4j.store.embedding.redis; /** - * Redis vector field distance Metric + * Similarity metric used by Redis */ -public enum MetricType { +enum MetricType { /** * cosine similarity */ COSINE, + /** * inner product */ IP, + /** * euclidean distance */ diff --git a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImpl.java b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java similarity index 50% rename from langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImpl.java rename to langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java index e3faba6b787..3f74ae16f3d 100644 --- a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImpl.java +++ b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java @@ -4,7 +4,6 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.internal.ValidationUtils; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import org.slf4j.Logger; @@ -19,32 +18,49 @@ import static dev.langchain4j.internal.Utils.isCollectionEmpty; import static dev.langchain4j.internal.Utils.randomUUID; -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.store.embedding.redis.RedisSchema.SCORE_FIELD_NAME; +import static java.lang.String.format; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; +import static redis.clients.jedis.search.IndexDefinition.Type.JSON; import static redis.clients.jedis.search.RediSearchUtil.ToByteArray; /** - * Redis Embedding Store Implementation + * Represents a Redis index as an embedding store. + * Current implementation assumes the index uses the cosine distance metric. */ -public class RedisEmbeddingStoreImpl implements EmbeddingStore { +public class RedisEmbeddingStore implements EmbeddingStore { - private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStoreImpl.class); + private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class); private static final Gson GSON = new Gson(); private final JedisPooled client; private final RedisSchema schema; - public RedisEmbeddingStoreImpl(String host, - Integer port, - String user, - String password, - Integer dimension, - List metadataFieldsName) { - host = ensureNotNull(host, "url"); + /** + * Creates an instance of RedisEmbeddingStore + * + * @param host Redis Stack Server host + * @param port Redis Stack Server port + * @param user Redis Stack username (optional) + * @param password Redis Stack password (optional) + * @param dimension embedding vector dimension + * @param metadataFieldsName metadata fields name (optional) + */ + public RedisEmbeddingStore(String host, + Integer port, + String user, + String password, + Integer dimension, + List metadataFieldsName) { + ensureNotBlank(host, "host"); ensureNotNull(port, "port"); ensureNotNull(dimension, "dimension"); - client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password); - schema = RedisSchema.builder() + this.client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password); + this.schema = RedisSchema.builder() .dimension(dimension) .metadataFieldsName(metadataFieldsName) .build(); @@ -77,7 +93,7 @@ public String add(Embedding embedding, TextSegment textSegment) { public List addAll(List embeddings) { List ids = embeddings.stream() .map(ignored -> randomUUID()) - .collect(Collectors.toList()); + .collect(toList()); addAllInternal(ids, embeddings, null); return ids; } @@ -86,7 +102,7 @@ public List addAll(List embeddings) { public List addAll(List embeddings, List embedded) { List ids = embeddings.stream() .map(ignored -> randomUUID()) - .collect(Collectors.toList()); + .collect(toList()); addAllInternal(ids, embeddings, embedded); return ids; } @@ -96,11 +112,11 @@ public List> findRelevant(Embedding referenceEmbeddi // Using KNN query on @vector field String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]"; List returnFields = new ArrayList<>(schema.getMetadataFieldsName()); - returnFields.addAll(Arrays.asList(schema.getVectorFieldName(), schema.getScalarFieldName(), RedisSchema.SCORE_FIELD_NAME)); - Query query = new Query(String.format(queryTemplate, maxResults, schema.getVectorFieldName(), RedisSchema.SCORE_FIELD_NAME)) + returnFields.addAll(asList(schema.getVectorFieldName(), schema.getScalarFieldName(), SCORE_FIELD_NAME)); + Query query = new Query(format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME)) .addParam("BLOB", ToByteArray(referenceEmbedding.vector())) .returnFields(returnFields.toArray(new String[0])) - .setSortBy(RedisSchema.SCORE_FIELD_NAME, true) + .setSortBy(SCORE_FIELD_NAME, true) .dialect(2); SearchResult result = client.ftSearch(schema.getIndexName(), query); @@ -110,7 +126,7 @@ public List> findRelevant(Embedding referenceEmbeddi } private void createIndex(String indexName) { - IndexDefinition indexDefinition = new IndexDefinition(IndexDefinition.Type.JSON); + IndexDefinition indexDefinition = new IndexDefinition(JSON); indexDefinition.setPrefixes(schema.getPrefix()); String res = client.ftCreate(indexName, FTCreateParams.createParams() .on(IndexDataType.JSON) @@ -129,7 +145,7 @@ private boolean isIndexExist(String indexName) { } private void addInternal(String id, Embedding embedding, TextSegment embedded) { - addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded)); + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); } private void addAllInternal(List ids, List embeddings, List embedded) { @@ -137,8 +153,8 @@ private void addAllInternal(List ids, List embeddings, List> toEmbeddingMatch(List docume return new ArrayList<>(); } - return documents.stream().map(document -> { - double score = (2 - Double.parseDouble(document.getString(RedisSchema.SCORE_FIELD_NAME))) / 2; - String id = document.getId().substring(schema.getPrefix().length()); - String text = document.hasProperty(schema.getScalarFieldName()) ? document.getString(schema.getScalarFieldName()) : null; - TextSegment embedded = null; - if (text != null) { - List metadataFieldsName = schema.getMetadataFieldsName(); - Map metadata = metadataFieldsName.stream() - .filter(document::hasProperty) - .collect(Collectors.toMap(metadataFieldName -> metadataFieldName, document::getString)); - embedded = new TextSegment(text, new Metadata(metadata)); - } - Embedding embedding = new Embedding(GSON.fromJson(document.getString(schema.getVectorFieldName()), float[].class)); - return new EmbeddingMatch<>(score, id, embedding, embedded); - }).filter(embeddingMatch -> embeddingMatch.score() >= minScore).collect(Collectors.toList()); + return documents.stream() + .map(document -> { + double score = (2 - Double.parseDouble(document.getString(SCORE_FIELD_NAME))) / 2; + String id = document.getId().substring(schema.getPrefix().length()); + String text = document.hasProperty(schema.getScalarFieldName()) ? document.getString(schema.getScalarFieldName()) : null; + TextSegment embedded = null; + if (text != null) { + List metadataFieldsName = schema.getMetadataFieldsName(); + Map metadata = metadataFieldsName.stream() + .filter(document::hasProperty) + .collect(Collectors.toMap(metadataFieldName -> metadataFieldName, document::getString)); + embedded = new TextSegment(text, new Metadata(metadata)); + } + Embedding embedding = new Embedding(GSON.fromJson(document.getString(schema.getVectorFieldName()), float[].class)); + return new EmbeddingMatch<>(score, id, embedding, embedded); + }) + .filter(embeddingMatch -> embeddingMatch.score() >= minScore) + .collect(toList()); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String host; + private Integer port; + private String user; + private String password; + private Integer dimension; + private List metadataFieldsName; + + /** + * @param host Redis Stack host + */ + public Builder host(String host) { + this.host = host; + return this; + } + + /** + * @param port Redis Stack port + */ + public Builder port(Integer port) { + this.port = port; + return this; + } + + /** + * @param user Redis Stack username (optional) + */ + public Builder user(String user) { + this.user = user; + return this; + } + + /** + * @param password Redis Stack password (optional) + */ + public Builder password(String password) { + this.password = password; + return this; + } + + /** + * @param dimension embedding vector dimension + * @return builder + */ + public Builder dimension(Integer dimension) { + this.dimension = dimension; + return this; + } + + /** + * @param metadataFieldsName metadata fields name (optional) + */ + public Builder metadataFieldsName(List metadataFieldsName) { + this.metadataFieldsName = metadataFieldsName; + return this; + } + + public RedisEmbeddingStore build() { + return new RedisEmbeddingStore(host, port, user, password, dimension, metadataFieldsName); + } } } diff --git a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisSchema.java b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisSchema.java index 5c0d39897ce..a09c38c71f2 100644 --- a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisSchema.java +++ b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisSchema.java @@ -5,23 +5,27 @@ import redis.clients.jedis.search.schemafields.SchemaField; import redis.clients.jedis.search.schemafields.TextField; import redis.clients.jedis.search.schemafields.VectorField; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import static dev.langchain4j.store.embedding.redis.MetricType.COSINE; +import static redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm.HNSW; + /** * Redis Schema Description */ @Builder @AllArgsConstructor -public class RedisSchema { +class RedisSchema { public static final String SCORE_FIELD_NAME = "vector_score"; private static final String JSON_PATH_PREFIX = "$."; - private static final VectorField.VectorAlgorithm DEFAULT_VECTOR_ALGORITHM = VectorField.VectorAlgorithm.HNSW; - private static final MetricType DEFAULT_METRIC_TYPE = MetricType.COSINE; + private static final VectorAlgorithm DEFAULT_VECTOR_ALGORITHM = HNSW; + private static final MetricType DEFAULT_METRIC_TYPE = COSINE; /* Redis schema field settings */ @@ -39,7 +43,7 @@ public class RedisSchema { /* Vector field settings */ @Builder.Default - private VectorField.VectorAlgorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; + private VectorAlgorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; private int dimension; @Builder.Default private MetricType metricType = DEFAULT_METRIC_TYPE; diff --git a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImplTest.java b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImplTest.java deleted file mode 100644 index 261019a900e..00000000000 --- a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreImplTest.java +++ /dev/null @@ -1,85 +0,0 @@ -package dev.langchain4j.store.embedding.redis; - -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.internal.Utils; -import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingStore; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * disabled default, because this need local deployment of Redis - */ -@Disabled -class RedisEmbeddingStoreImplTest { - - private final EmbeddingStore store = new RedisEmbeddingStoreImpl("localhost", 6379, - "default", "password", 4, Collections.singletonList("field")); - - @Test - void testAdd() { - // test add without id - String id = store.add(Embedding.from(Arrays.asList(0.50f, 0.85f, 0.760f, 0.24f)), - TextSegment.from("test string", Metadata.metadata("field", "value"))); - System.out.println("id=" + id); - - // test add with id - String selfId = Utils.randomUUID(); - store.add(selfId, Embedding.from(Arrays.asList(0.80f, 0.45f, 0.89f, 0.24f))); - System.out.println("id=" + selfId); - } - - @Test - void testAddAll() { - // test add All Method without embedded - List ids = store.addAll(Arrays.asList( - Embedding.from(Arrays.asList(0.3f, 0.87f, 0.90f, 0.24f)), - Embedding.from(Arrays.asList(0.54f, 0.34f, 0.67f, 0.24f)), - Embedding.from(Arrays.asList(0.80f, 0.45f, 0.779f, 0.5556f)) - )); - System.out.println("ids=" + ids); - - // test add all method with embedded - ids = store.addAll(Arrays.asList( - Embedding.from(Arrays.asList(0.3f, 0.87f, 0.90f, 0.24f)), - Embedding.from(Arrays.asList(0.54f, 0.34f, 0.67f, 0.24f)), - Embedding.from(Arrays.asList(0.80f, 0.45f, 0.779f, 0.5556f)) - ), Arrays.asList( - TextSegment.from("testString1", Metadata.metadata("field", "value1")), - TextSegment.from("testString2", Metadata.metadata("field", "value2")), - TextSegment.from("testingString3", Metadata.metadata("field", "value3")) - )); - System.out.println("ids=" + ids); - } - - @Test - void testAddEmpty() { - // see log - store.addAll(Collections.emptyList()); - } - - @Test - void testFindRelevant() { - List> res = store.findRelevant(Embedding.from(Arrays.asList(0.80f, 0.45f, 0.89f, 0.24f)), 5); - res.forEach(System.out::println); - } - - @Test - void testScore() { - String id = store.add(Embedding.from(Arrays.asList(0.50f, 0.85f, 0.760f, 0.24f)), - TextSegment.from("test string", Metadata.metadata("field", "value"))); - System.out.println("id=" + id); - - // use the same embedding to search - List> res = store.findRelevant(Embedding.from(Arrays.asList(0.50f, 0.85f, 0.760f, 0.24f)), 1); - res.forEach(System.out::println); - - // the result embeddingMatch score is 5.96046447754E-8, but expected is 1 because they are same vectors. - } -} diff --git a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java new file mode 100644 index 00000000000..ce02422159f --- /dev/null +++ b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java @@ -0,0 +1,96 @@ +package dev.langchain4j.store.embedding.redis; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.internal.Utils; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + +@Disabled("needs Redis running locally") +class RedisEmbeddingStoreTest { + + /** + * First start Redis locally: + * docker pull redis/redis-stack:latest + * docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest + */ + + private final EmbeddingStore store = new RedisEmbeddingStore( + "localhost", + 6379, + "default", + "password", + 4, + singletonList("field") + ); + + @Test + void testAdd() { + // test add without id + String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), + TextSegment.from("test string", Metadata.from("field", "value"))); + System.out.println("id=" + id); + + // test add with id + String selfId = Utils.randomUUID(); + store.add(selfId, Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f))); + System.out.println("id=" + selfId); + } + + @Test + void testAddAll() { + // test add All Method without embedded + List ids = store.addAll(asList( + Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), + Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)), + Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) + )); + System.out.println("ids=" + ids); + + // test add all method with embedded + ids = store.addAll(asList( + Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), + Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)), + Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) + ), asList( + TextSegment.from("testString1", Metadata.from("field", "value1")), + TextSegment.from("testString2", Metadata.from("field", "value2")), + TextSegment.from("testingString3", Metadata.from("field", "value3")) + )); + System.out.println("ids=" + ids); + } + + @Test + void testAddEmpty() { + // see log + store.addAll(emptyList()); + } + + @Test + void testFindRelevant() { + List> res = store.findRelevant(Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)), 5); + res.forEach(System.out::println); + } + + @Test + void testScore() { + String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), + TextSegment.from("test string", Metadata.from("field", "value"))); + System.out.println("id=" + id); + + // use the same embedding to search + List> res = store.findRelevant(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), 1); + res.forEach(System.out::println); + + // the result embeddingMatch score is 5.96046447754E-8, but expected is 1 because they are same vectors. + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java deleted file mode 100644 index a6b49ad59c9..00000000000 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java +++ /dev/null @@ -1,169 +0,0 @@ -package dev.langchain4j.store.embedding.redis; - -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingStore; - -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.List; - -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; - -/** - * Represents a Redis index as an embedding store. - * Current implementation assumes the index uses the cosine distance metric. - * To use RedisEmbeddingStore, please add the "langchain4j-redis" dependency to your project. - */ -public class RedisEmbeddingStore implements EmbeddingStore { - - private final EmbeddingStore implementation; - - /** - * Creates an instance of RedisEmbeddingStore - * - * @param host Redis Stack Server host - * @param port Redis Stack Server Port - * @param user Redis Stack username - * @param password Redis Stack password - * @param dimension vector dimension - * @param metadataFieldsName metadata fields name - */ - public RedisEmbeddingStore(String host, Integer port, String user, String password, Integer dimension, List metadataFieldsName) { - ensureNotNull(port, "port"); - ensureNotNull(dimension, "dimension"); - try { - implementation = loadDynamically( - "dev.langchain4j.store.embedding.redis.RedisEmbeddingStoreImpl", - host, port, user, password, dimension, metadataFieldsName - ); - } catch (ClassNotFoundException e) { - throw new RuntimeException(getMessage(), e); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static String getMessage() { - return "To use RedisEmbeddingStore, please add the following dependency to your project:\n\n" - + "Maven:\n" - + "\n" + - " dev.langchain4j\n" + - " langchain4j-redis\n" + - " 0.22.0\n" + - "\n\n" - + "Gradle:\n" - + "implementation 'dev.langchain4j:langchain4j-redis:0.22.0'\n"; - } - - private static EmbeddingStore loadDynamically(String implementationClassName, String host, int port, - String user, String password, int dimension, List metadataFieldsName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException { - Class implementationClass = Class.forName(implementationClassName); - Class[] constructorParameterTypes = new Class[]{String.class, Integer.class, String.class, String.class, Integer.class, List.class}; - Constructor constructor = implementationClass.getConstructor(constructorParameterTypes); - return (EmbeddingStore) constructor.newInstance(host, port, user, password, dimension, metadataFieldsName); - } - - public static Builder builder() { - return new Builder(); - } - - @Override - public String add(Embedding embedding) { - return implementation.add(embedding); - } - - @Override - public void add(String id, Embedding embedding) { - implementation.add(id, embedding); - } - - @Override - public String add(Embedding embedding, TextSegment textSegment) { - return implementation.add(embedding, textSegment); - } - - @Override - public List addAll(List embeddings) { - return implementation.addAll(embeddings); - } - - @Override - public List addAll(List embeddings, List textSegments) { - return implementation.addAll(embeddings, textSegments); - } - - @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults) { - return implementation.findRelevant(referenceEmbedding, maxResults); - } - - @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { - return implementation.findRelevant(referenceEmbedding, maxResults, minScore); - } - - public static class Builder { - - private String host; - private Integer port; - private String user; - private String password; - private Integer dimension; - private List metadataFieldsName; - - /** - * @param host Redis Stack host - */ - public Builder host(String host) { - this.host = host; - return this; - } - - /** - * @param port Redis Stack port - */ - public Builder port(Integer port) { - this.port = port; - return this; - } - - /** - * @param user Redis Stack username - */ - public Builder user(String user) { - this.user = user; - return this; - } - - /** - * @param password Redis Stack password - */ - public Builder password(String password) { - this.password = password; - return this; - } - - /** - * @param dimension vector dimension - * @return builder - */ - public Builder dimension(Integer dimension) { - this.dimension = dimension; - return this; - } - - /** - * @param metadataFieldsName metadata fields name - */ - public Builder metadataFieldsName(List metadataFieldsName) { - this.metadataFieldsName = metadataFieldsName; - return this; - } - - public RedisEmbeddingStore build() { - return new RedisEmbeddingStore(host, port, user, password, dimension, metadataFieldsName); - } - } -}