Skip to content

Commit

Permalink
fix: fix Spring AI embeddings for Hashes (List<Double> to byte[] conv…
Browse files Browse the repository at this point in the history
…ersion was buggy)
  • Loading branch information
bsbodden committed Jun 14, 2024
1 parent c5622ee commit d97c3e1
Show file tree
Hide file tree
Showing 29 changed files with 1,171 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import com.redis.om.spring.indexing.RediSearchIndexer;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.ops.search.SearchOperations;
import com.redis.om.spring.vectorize.FeatureExtractor;
import com.redis.om.spring.vectorize.Embedder;
import org.springframework.data.convert.CustomConversions;
import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.redis.connection.RedisConnection;
Expand Down Expand Up @@ -41,7 +41,7 @@ public class RedisEnhancedKeyValueAdapter extends RedisKeyValueAdapter {
private final RedisModulesOperations<String> modulesOperations;
private final RediSearchIndexer indexer;
private final EntityAuditor auditor;
private final FeatureExtractor featureExtractor;
private final Embedder embedder;
private final RedisOMProperties redisOMProperties;

/**
Expand All @@ -56,9 +56,9 @@ public RedisEnhancedKeyValueAdapter( //
RedisOperations<?, ?> redisOps, //
RedisModulesOperations<?> rmo, //
RediSearchIndexer indexer, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties redisOMProperties) {
this(redisOps, rmo, new RedisMappingContext(), indexer, featureExtractor, redisOMProperties);
this(redisOps, rmo, new RedisMappingContext(), indexer, embedder, redisOMProperties);
}

/**
Expand All @@ -75,9 +75,9 @@ public RedisEnhancedKeyValueAdapter( //
RedisModulesOperations<?> rmo, //
RedisMappingContext mappingContext, //
RediSearchIndexer indexer, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties redisOMProperties) {
this(redisOps, rmo, mappingContext, new RedisOMCustomConversions(), indexer, featureExtractor, redisOMProperties);
this(redisOps, rmo, mappingContext, new RedisOMCustomConversions(), indexer, embedder, redisOMProperties);
}

/**
Expand All @@ -96,7 +96,7 @@ public RedisEnhancedKeyValueAdapter( //
RedisMappingContext mappingContext, //
@Nullable CustomConversions customConversions, //
RediSearchIndexer indexer, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties redisOMProperties) {
super(redisOps, mappingContext, customConversions);

Expand All @@ -114,7 +114,7 @@ public RedisEnhancedKeyValueAdapter( //
this.modulesOperations = (RedisModulesOperations<String>) rmo;
this.indexer = indexer;
this.auditor = new EntityAuditor(this.redisOperations);
this.featureExtractor = featureExtractor;
this.embedder = embedder;
this.redisOMProperties = redisOMProperties;
}

Expand All @@ -141,7 +141,7 @@ public Object put(Object id, Object item, String keyspace) {
}
byte[] redisKey = createKey(sanitizeKeyspace(keyspace), idAsString);
auditor.processEntity(redisKey, item);
featureExtractor.processEntity(item);
embedder.processEntity(item);

rdo = new RedisData();
converter.write(item, rdo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.redis.om.spring.ops.json.JSONOperations;
import com.redis.om.spring.ops.search.SearchOperations;
import com.redis.om.spring.util.ObjectUtils;
import com.redis.om.spring.vectorize.FeatureExtractor;
import com.redis.om.spring.vectorize.Embedder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.*;
Expand Down Expand Up @@ -53,7 +53,7 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter {
private final RediSearchIndexer indexer;
private final GsonBuilder gsonBuilder;
private final EntityAuditor auditor;
private final FeatureExtractor featureExtractor;
private final Embedder embedder;
private final RedisOMProperties redisOMProperties;

/**
Expand All @@ -72,7 +72,7 @@ public RedisJSONKeyValueAdapter( //
RedisMappingContext mappingContext, //
RediSearchIndexer indexer, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties redisOMProperties) {
super(redisOps, mappingContext, new RedisOMCustomConversions());
this.modulesOperations = (RedisModulesOperations<String>) rmo;
Expand All @@ -82,7 +82,7 @@ public RedisJSONKeyValueAdapter( //
this.indexer = indexer;
this.auditor = new EntityAuditor(this.redisOperations);
this.gsonBuilder = gsonBuilder;
this.featureExtractor = featureExtractor;
this.embedder = embedder;
this.redisOMProperties = redisOMProperties;
}

Expand All @@ -102,7 +102,7 @@ public Object put(Object id, Object item, String keyspace) {

processVersion(key, item);
auditor.processEntity(key, item);
featureExtractor.processEntity(item);
embedder.processEntity(item);
Optional<Long> maybeTtl = getTTLForEntity(item);

ops.set(key, item);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import com.redis.om.spring.search.stream.EntityStream;
import com.redis.om.spring.search.stream.EntityStreamImpl;
import com.redis.om.spring.serialization.gson.*;
import com.redis.om.spring.vectorize.DefaultFeatureExtractor;
import com.redis.om.spring.vectorize.FeatureExtractor;
import com.redis.om.spring.vectorize.NoopFeatureExtractor;
import com.redis.om.spring.vectorize.DefaultEmbedder;
import com.redis.om.spring.vectorize.Embedder;
import com.redis.om.spring.vectorize.NoopEmbedder;
import com.redis.om.spring.vectorize.face.FaceDetectionTranslator;
import com.redis.om.spring.vectorize.face.FaceFeatureTranslator;
import org.apache.commons.lang3.ObjectUtils;
Expand Down Expand Up @@ -562,7 +562,7 @@ BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel(RedisOMProperties properti
}

@Bean(name = "featureExtractor")
public FeatureExtractor featureExtractor(
public Embedder featureExtractor(
@Nullable @Qualifier("djlImageEmbeddingModel") ZooModel<Image, byte[]> imageEmbeddingModel,
@Nullable @Qualifier("djlFaceEmbeddingModel") ZooModel<Image, float[]> faceEmbeddingModel,
@Nullable @Qualifier("djlImageFactory") ImageFactory imageFactory,
Expand All @@ -574,10 +574,10 @@ public FeatureExtractor featureExtractor(
@Nullable BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel, RedisOMProperties properties,
ApplicationContext ac) {
return properties.getDjl().isEnabled() ?
new DefaultFeatureExtractor(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline,
new DefaultEmbedder(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline,
sentenceTokenizer, openAITextVectorizer, azureOpenAIClient, vertexAiPaLm2EmbeddingModel,
bedrockCohereEmbeddingModel, bedrockTitanEmbeddingModel, properties) :
new NoopFeatureExtractor();
new NoopEmbedder();
}

@Bean(name = "redisJSONKeyValueAdapter")
Expand All @@ -588,9 +588,9 @@ RedisJSONKeyValueAdapter getRedisJSONKeyValueAdapter( //
RediSearchIndexer indexer, //
@Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, //
RedisOMProperties properties, //
@Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) {
@Nullable @Qualifier("featureExtractor") Embedder embedder) {
return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder,
featureExtractor, properties);
embedder, properties);
}

@Bean(name = "redisJSONKeyValueTemplate")
Expand All @@ -601,10 +601,10 @@ public CustomRedisKeyValueTemplate getRedisJSONKeyValueTemplate( //
RediSearchIndexer indexer, //
@Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, //
RedisOMProperties properties, //
@Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) {
@Nullable @Qualifier("featureExtractor") Embedder embedder) {
return new CustomRedisKeyValueTemplate(
new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder,
featureExtractor, properties), mappingContext);
new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, embedder,
properties), mappingContext);
}

@Bean(name = "redisCustomKeyValueTemplate")
Expand All @@ -614,9 +614,9 @@ public CustomRedisKeyValueTemplate getKeyValueTemplate( //
RedisMappingContext mappingContext, //
RediSearchIndexer indexer, //
RedisOMProperties properties, //
@Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) {
@Nullable @Qualifier("featureExtractor") Embedder embedder) {
return new CustomRedisKeyValueTemplate(
new RedisEnhancedKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, featureExtractor,
new RedisEnhancedKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, embedder,
properties), //
mappingContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import com.redis.om.spring.indexing.RediSearchIndexer;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.repository.query.RediSearchQuery;
import com.redis.om.spring.vectorize.FeatureExtractor;
import com.redis.om.spring.vectorize.Embedder;
import org.springframework.beans.BeanUtils;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.keyvalue.repository.query.KeyValuePartTreeQuery;
Expand Down Expand Up @@ -42,7 +42,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory {
private final RediSearchIndexer indexer;
private final GsonBuilder gsonBuilder;
private final RedisMappingContext mappingContext;
private final FeatureExtractor featureExtractor;
private final Embedder embedder;
private final RedisOMProperties properties;

/**
Expand All @@ -54,7 +54,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory {
* @param indexer must not be {@literal null}.
* @param mappingContext must not be {@literal null}.
* @param gsonBuilder must not be {@literal null}.
* @param featureExtractor must not be {@literal null}.
* @param embedder must not be {@literal null}.
* @param properties must not be {@literal null}.
*/
public RedisDocumentRepositoryFactory( //
Expand All @@ -63,7 +63,7 @@ public RedisDocumentRepositoryFactory( //
RediSearchIndexer indexer, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties properties //
) {
this( //
Expand All @@ -73,7 +73,7 @@ public RedisDocumentRepositoryFactory( //
DEFAULT_QUERY_CREATOR, //
mappingContext, //
gsonBuilder, //
featureExtractor, //
embedder, //
properties //
); //
}
Expand All @@ -96,7 +96,7 @@ public RedisDocumentRepositoryFactory( //
Class<? extends AbstractQueryCreator<?, ?>> queryCreator, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties properties //
) {
this( //
Expand All @@ -107,7 +107,7 @@ public RedisDocumentRepositoryFactory( //
RediSearchQuery.class, //
mappingContext, //
gsonBuilder, //
featureExtractor, //
embedder, //
properties //
);
}
Expand All @@ -123,7 +123,7 @@ public RedisDocumentRepositoryFactory( //
* @param repositoryQueryType must not be {@literal null}.
* @param mappingContext must not be {@literal null}.
* @param gsonBuilder must not be {@literal null}.
* @param featureExtractor must not be {@literal null}.
* @param embedder must not be {@literal null}.
* @param properties must not be {@literal null}.
*/
public RedisDocumentRepositoryFactory( //
Expand All @@ -134,7 +134,7 @@ public RedisDocumentRepositoryFactory( //
Class<? extends RepositoryQuery> repositoryQueryType, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
Embedder embedder, //
RedisOMProperties properties //
) {

Expand All @@ -145,7 +145,7 @@ public RedisDocumentRepositoryFactory( //
Assert.notNull(indexer, "RediSearchIndexer must not be null!");
Assert.notNull(queryCreator, "Query creator type must not be null!");
Assert.notNull(repositoryQueryType, "RepositoryQueryType type must not be null!");
Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!");
Assert.notNull(embedder, "FeatureExtractor type must not be null!");
Assert.notNull(properties, "RedisOMSpringProperties type must not be null!");

this.keyValueOperations = keyValueOperations;
Expand All @@ -155,7 +155,7 @@ public RedisDocumentRepositoryFactory( //
this.repositoryQueryType = repositoryQueryType;
this.mappingContext = mappingContext;
this.gsonBuilder = gsonBuilder;
this.featureExtractor = featureExtractor;
this.embedder = embedder;
this.properties = properties;
}

Expand All @@ -170,7 +170,7 @@ protected Object getTargetRepository(RepositoryInformation repositoryInformation
indexer, //
mappingContext, //
gsonBuilder, //
featureExtractor, //
embedder, //
properties //
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.redis.om.spring.RedisOMProperties;
import com.redis.om.spring.indexing.RediSearchIndexer;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.vectorize.FeatureExtractor;
import com.redis.om.spring.vectorize.Embedder;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.keyvalue.repository.support.KeyValueRepositoryFactoryBean;
Expand All @@ -27,7 +27,7 @@ public class RedisDocumentRepositoryFactoryBean<T extends Repository<S, ID>, S,
@Autowired
private GsonBuilder gsonBuilder;
@Autowired
private @Nullable FeatureExtractor featureExtractor;
private @Nullable Embedder embedder;
@Autowired
private RedisOMProperties properties;

Expand All @@ -49,7 +49,7 @@ protected final RedisDocumentRepositoryFactory createRepositoryFactory( //
Class<? extends RepositoryQuery> repositoryQueryType //
) {
return new RedisDocumentRepositoryFactory(operations, rmo, indexer, queryCreator, repositoryQueryType,
this.mappingContext, this.gsonBuilder, this.featureExtractor, this.properties);
this.mappingContext, this.gsonBuilder, this.embedder, this.properties);
}

@Override
Expand Down
Loading

0 comments on commit d97c3e1

Please sign in to comment.