Skip to content

Commit

Permalink
Removed dynamic loading from RedisEmbeddingStore
Browse files Browse the repository at this point in the history
  • Loading branch information
deep-learning-dynamo committed Sep 26, 2023
1 parent 887120b commit 9c62fe9
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 318 deletions.
28 changes: 8 additions & 20 deletions langchain4j-redis/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,26 @@
<packaging>jar</packaging>

<name>LangChain4j integration with Redis</name>
<description>Uses jedis library which has a MIT license:
https://github.com/redis/jedis/blob/master/LICENSE
</description>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
Expand All @@ -47,16 +44,19 @@
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
Expand All @@ -68,19 +68,7 @@
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.honton.chas</groupId>
<artifactId>license-maven-plugin</artifactId>
<configuration>
<!-- jedis use MIT license but is not acceptable(I don't know why) -->
<skipCompliance>true</skipCompliance>
</configuration>
</plugin>
</plugins>
</build>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
package dev.langchain4j.store.embedding.redis;

/**
* Redis vector field distance Metric
* Similarity metric used by Redis
*/
public enum MetricType {
enum MetricType {

/**
* cosine similarity
*/
COSINE,

/**
* inner product
*/
IP,

/**
* euclidean distance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.slf4j.Logger;
Expand All @@ -19,32 +18,49 @@

import static dev.langchain4j.internal.Utils.isCollectionEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.*;
import static dev.langchain4j.store.embedding.redis.RedisSchema.SCORE_FIELD_NAME;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static redis.clients.jedis.search.IndexDefinition.Type.JSON;
import static redis.clients.jedis.search.RediSearchUtil.ToByteArray;

/**
* Redis Embedding Store Implementation
* Represents a <a href="https://redis.io/">Redis</a> index as an embedding store.
* Current implementation assumes the index uses the cosine distance metric.
*/
public class RedisEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {

private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStoreImpl.class);
private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
private static final Gson GSON = new Gson();

private final JedisPooled client;
private final RedisSchema schema;

public RedisEmbeddingStoreImpl(String host,
Integer port,
String user,
String password,
Integer dimension,
List<String> metadataFieldsName) {
host = ensureNotNull(host, "url");
/**
* Creates an instance of RedisEmbeddingStore
*
* @param host Redis Stack Server host
* @param port Redis Stack Server port
* @param user Redis Stack username (optional)
* @param password Redis Stack password (optional)
* @param dimension embedding vector dimension
* @param metadataFieldsName metadata fields name (optional)
*/
public RedisEmbeddingStore(String host,
Integer port,
String user,
String password,
Integer dimension,
List<String> metadataFieldsName) {
ensureNotBlank(host, "host");
ensureNotNull(port, "port");
ensureNotNull(dimension, "dimension");

client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password);
schema = RedisSchema.builder()
this.client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password);
this.schema = RedisSchema.builder()
.dimension(dimension)
.metadataFieldsName(metadataFieldsName)
.build();
Expand Down Expand Up @@ -77,7 +93,7 @@ public String add(Embedding embedding, TextSegment textSegment) {
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = embeddings.stream()
.map(ignored -> randomUUID())
.collect(Collectors.toList());
.collect(toList());
addAllInternal(ids, embeddings, null);
return ids;
}
Expand All @@ -86,7 +102,7 @@ public List<String> addAll(List<Embedding> embeddings) {
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
List<String> ids = embeddings.stream()
.map(ignored -> randomUUID())
.collect(Collectors.toList());
.collect(toList());
addAllInternal(ids, embeddings, embedded);
return ids;
}
Expand All @@ -96,11 +112,11 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
// Using KNN query on @vector field
String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
List<String> returnFields = new ArrayList<>(schema.getMetadataFieldsName());
returnFields.addAll(Arrays.asList(schema.getVectorFieldName(), schema.getScalarFieldName(), RedisSchema.SCORE_FIELD_NAME));
Query query = new Query(String.format(queryTemplate, maxResults, schema.getVectorFieldName(), RedisSchema.SCORE_FIELD_NAME))
returnFields.addAll(asList(schema.getVectorFieldName(), schema.getScalarFieldName(), SCORE_FIELD_NAME));
Query query = new Query(format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME))
.addParam("BLOB", ToByteArray(referenceEmbedding.vector()))
.returnFields(returnFields.toArray(new String[0]))
.setSortBy(RedisSchema.SCORE_FIELD_NAME, true)
.setSortBy(SCORE_FIELD_NAME, true)
.dialect(2);

SearchResult result = client.ftSearch(schema.getIndexName(), query);
Expand All @@ -110,7 +126,7 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
}

private void createIndex(String indexName) {
IndexDefinition indexDefinition = new IndexDefinition(IndexDefinition.Type.JSON);
IndexDefinition indexDefinition = new IndexDefinition(JSON);
indexDefinition.setPrefixes(schema.getPrefix());
String res = client.ftCreate(indexName, FTCreateParams.createParams()
.on(IndexDataType.JSON)
Expand All @@ -129,16 +145,16 @@ private boolean isIndexExist(String indexName) {
}

private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded));
}

private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isCollectionEmpty(ids) || isCollectionEmpty(embeddings)) {
log.info("do not add empty embeddings to redis");
return;
}
ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

Pipeline pipeline = client.pipelined();

Expand Down Expand Up @@ -172,20 +188,90 @@ private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> docume
return new ArrayList<>();
}

return documents.stream().map(document -> {
double score = (2 - Double.parseDouble(document.getString(RedisSchema.SCORE_FIELD_NAME))) / 2;
String id = document.getId().substring(schema.getPrefix().length());
String text = document.hasProperty(schema.getScalarFieldName()) ? document.getString(schema.getScalarFieldName()) : null;
TextSegment embedded = null;
if (text != null) {
List<String> metadataFieldsName = schema.getMetadataFieldsName();
Map<String, String> metadata = metadataFieldsName.stream()
.filter(document::hasProperty)
.collect(Collectors.toMap(metadataFieldName -> metadataFieldName, document::getString));
embedded = new TextSegment(text, new Metadata(metadata));
}
Embedding embedding = new Embedding(GSON.fromJson(document.getString(schema.getVectorFieldName()), float[].class));
return new EmbeddingMatch<>(score, id, embedding, embedded);
}).filter(embeddingMatch -> embeddingMatch.score() >= minScore).collect(Collectors.toList());
return documents.stream()
.map(document -> {
double score = (2 - Double.parseDouble(document.getString(SCORE_FIELD_NAME))) / 2;
String id = document.getId().substring(schema.getPrefix().length());
String text = document.hasProperty(schema.getScalarFieldName()) ? document.getString(schema.getScalarFieldName()) : null;
TextSegment embedded = null;
if (text != null) {
List<String> metadataFieldsName = schema.getMetadataFieldsName();
Map<String, String> metadata = metadataFieldsName.stream()
.filter(document::hasProperty)
.collect(Collectors.toMap(metadataFieldName -> metadataFieldName, document::getString));
embedded = new TextSegment(text, new Metadata(metadata));
}
Embedding embedding = new Embedding(GSON.fromJson(document.getString(schema.getVectorFieldName()), float[].class));
return new EmbeddingMatch<>(score, id, embedding, embedded);
})
.filter(embeddingMatch -> embeddingMatch.score() >= minScore)
.collect(toList());
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private String host;
private Integer port;
private String user;
private String password;
private Integer dimension;
private List<String> metadataFieldsName;

/**
* @param host Redis Stack host
*/
public Builder host(String host) {
this.host = host;
return this;
}

/**
* @param port Redis Stack port
*/
public Builder port(Integer port) {
this.port = port;
return this;
}

/**
* @param user Redis Stack username (optional)
*/
public Builder user(String user) {
this.user = user;
return this;
}

/**
* @param password Redis Stack password (optional)
*/
public Builder password(String password) {
this.password = password;
return this;
}

/**
* @param dimension embedding vector dimension
* @return builder
*/
public Builder dimension(Integer dimension) {
this.dimension = dimension;
return this;
}

/**
* @param metadataFieldsName metadata fields name (optional)
*/
public Builder metadataFieldsName(List<String> metadataFieldsName) {
this.metadataFieldsName = metadataFieldsName;
return this;
}

public RedisEmbeddingStore build() {
return new RedisEmbeddingStore(host, port, user, password, dimension, metadataFieldsName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.search.schemafields.VectorField;
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.store.embedding.redis.MetricType.COSINE;
import static redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm.HNSW;

/**
* Redis Schema Description
*/
@Builder
@AllArgsConstructor
public class RedisSchema {
class RedisSchema {

public static final String SCORE_FIELD_NAME = "vector_score";
private static final String JSON_PATH_PREFIX = "$.";
private static final VectorField.VectorAlgorithm DEFAULT_VECTOR_ALGORITHM = VectorField.VectorAlgorithm.HNSW;
private static final MetricType DEFAULT_METRIC_TYPE = MetricType.COSINE;
private static final VectorAlgorithm DEFAULT_VECTOR_ALGORITHM = HNSW;
private static final MetricType DEFAULT_METRIC_TYPE = COSINE;

/* Redis schema field settings */

Expand All @@ -39,7 +43,7 @@ public class RedisSchema {
/* Vector field settings */

@Builder.Default
private VectorField.VectorAlgorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
private VectorAlgorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
private int dimension;
@Builder.Default
private MetricType metricType = DEFAULT_METRIC_TYPE;
Expand Down
Loading

0 comments on commit 9c62fe9

Please sign in to comment.