Skip to content

re-enable tests and upgrade to Spring AI 1.0.0 #594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ docs/build/
/out/
build/
.gradle/
compile_debug.log
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ elementaryVersion = 2.0.1
gsonVersion = 2.10.1
djlStarterVersion = 0.26
djlVersion = 0.30.0
springAiVersion = 1.0.0-M8
springAiVersion = 1.0.0
azureIdentityVersion = 1.15.4
19 changes: 9 additions & 10 deletions redis-om-spring-ai/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@ description = 'Redis OM Spring AI'
dependencies {
implementation project(':redis-om-spring')

compileOnly "org.springframework.ai:spring-ai-openai:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-openai:${springAiVersion}"
compileOnly "jakarta.websocket:jakarta.websocket-api"
compileOnly "jakarta.websocket:jakarta.websocket-client-api"
compileOnly "org.springframework.ai:spring-ai-ollama:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-azure-openai:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-vertex-ai-embedding:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-bedrock:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-transformers:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-ollama:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-azure-openai:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-vertex-ai-embedding:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-bedrock:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-transformers:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-mistral-ai:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-minimax:${springAiVersion}"
compileOnly "org.springframework.ai:spring-ai-zhipuai:${springAiVersion}"

compileOnly "com.google.auto.service:auto-service:${autoServiceVersion}"

// DJL Dependencies
compileOnly "ai.djl.spring:djl-spring-boot-starter-autoconfigure:${djlStarterVersion}"
compileOnly "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:${djlStarterVersion}"
compileOnly "ai.djl.huggingface:tokenizers:${djlVersion}"

implementation "ai.djl.spring:djl-spring-boot-starter-autoconfigure:${djlStarterVersion}"
implementation "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:${djlStarterVersion}"
implementation "ai.djl.huggingface:tokenizers:${djlVersion}"

testImplementation "org.mockito:mockito-core"
testImplementation "com.karuslabs:elementary:${elementaryVersion}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.lang.Nullable;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import io.micrometer.observation.ObservationRegistry;

import com.redis.om.spring.vectorize.DefaultEmbedder;
import com.redis.om.spring.vectorize.Embedder;
Expand Down Expand Up @@ -79,9 +84,40 @@ public class AIRedisConfiguration {
//// }

@Bean
public EmbeddingModelFactory embeddingModelFactory(AIRedisOMProperties properties,
SpringAiProperties springAiProperties) {
return new EmbeddingModelFactory(properties, springAiProperties);
public RestClient.Builder restClientBuilder() {
return RestClient.builder();
}

@Bean
public WebClient.Builder webClientBuilder() {
return WebClient.builder();
}

@Bean
public ResponseErrorHandler defaultResponseErrorHandler() {
return new DefaultResponseErrorHandler();
}

@Bean
public ObservationRegistry observationRegistry() {
return ObservationRegistry.create();
}

@Bean
public EmbeddingModelFactory embeddingModelFactory(
AIRedisOMProperties properties,
SpringAiProperties springAiProperties,
RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler,
ObservationRegistry observationRegistry) {
return new EmbeddingModelFactory(
properties,
springAiProperties,
restClientBuilder,
webClientBuilder,
responseErrorHandler,
observationRegistry);
}

@Bean(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import io.micrometer.observation.ObservationRegistry;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
Expand All @@ -43,10 +46,24 @@
public class EmbeddingModelFactory {
private final AIRedisOMProperties properties;
private final SpringAiProperties springAiProperties;

public EmbeddingModelFactory(AIRedisOMProperties properties, SpringAiProperties springAiProperties) {
private final RestClient.Builder restClientBuilder;
private final WebClient.Builder webClientBuilder;
private final ResponseErrorHandler responseErrorHandler;
private final ObservationRegistry observationRegistry;

public EmbeddingModelFactory(
AIRedisOMProperties properties,
SpringAiProperties springAiProperties,
RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler,
ObservationRegistry observationRegistry) {
this.properties = properties;
this.springAiProperties = springAiProperties;
this.restClientBuilder = restClientBuilder;
this.webClientBuilder = webClientBuilder;
this.responseErrorHandler = responseErrorHandler;
this.observationRegistry = observationRegistry;
}

public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vectorize) {
Expand Down Expand Up @@ -174,7 +191,12 @@ public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model)
}

public OllamaEmbeddingModel createOllamaEmbeddingModel(String model) {
OllamaApi api = new OllamaApi(properties.getOllama().getBaseUrl());
OllamaApi api = OllamaApi.builder()
.baseUrl(properties.getOllama().getBaseUrl())
.restClientBuilder(restClientBuilder)
.webClientBuilder(webClientBuilder)
.responseErrorHandler(responseErrorHandler)
.build();

OllamaOptions options = OllamaOptions.builder().model(model).truncate(false).build();

Expand Down Expand Up @@ -228,6 +250,6 @@ public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
properties.getAws().getRegion(), ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(properties.getAws()
.getBedrockTitan().getResponseTimeOut()));

return new BedrockTitanEmbeddingModel(titanEmbeddingApi);
return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
}
}
3 changes: 2 additions & 1 deletion redis-om-spring/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies {
api "com.squareup:javapoet:${javapoetVersion}"

compileOnly "javax.enterprise:cdi-api:${cdi}"
compileOnly "com.google.auto.service:auto-service:${autoServiceVersion}"
implementation "com.google.auto.service:auto-service:${autoServiceVersion}"
annotationProcessor "com.google.auto.service:auto-service:${autoServiceVersion}"

}
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ private <T> T readTimeToLiveIfSet(@Nullable byte[] key, @Nullable T target) {
}

RedisPersistentEntity<?> entity = this.converter.getMappingContext().getRequiredPersistentEntity(target.getClass());
if (entity.hasExplictTimeToLiveProperty()) {
if (entity.hasExplicitTimeToLiveProperty()) {

RedisPersistentProperty ttlProperty = entity.getExplicitTimeToLiveProperty();
if (ttlProperty == null) {
Expand Down
1 change: 1 addition & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ rootProject.name = 'redis-om-spring-parent'

include 'redis-om-spring'
include 'redis-om-spring-ai'
include 'tests'
105 changes: 105 additions & 0 deletions tests/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
plugins {
id 'java'
}

apply plugin: 'org.springframework.boot'
apply plugin: 'io.spring.dependency-management'

java {
toolchain {
languageVersion = JavaLanguageVersion.of(21)
}
}

// Don't publish this module
bootJar { enabled = false }
jar { enabled = false }
tasks.matching { it.name.startsWith('publish') }.configureEach {
enabled = false
}

repositories {
mavenLocal()
mavenCentral()
}

// Tell gradle to add the generated sources directory
sourceSets {
test {
java {
srcDir file("${buildDir}/generated/sources/annotationProcessor/java/test")
}
}
}

dependencies {
implementation project(':redis-om-spring')
implementation project(':redis-om-spring-ai')

// Important for RedisOM annotation processing!
annotationProcessor project(':redis-om-spring')
testAnnotationProcessor project(':redis-om-spring')

// Lombok
compileOnly 'org.projectlombok:lombok'
annotationProcessor 'org.projectlombok:lombok'
testCompileOnly 'org.projectlombok:lombok'
testAnnotationProcessor 'org.projectlombok:lombok'

// Spring
implementation 'org.springframework:spring-context-support'
implementation 'org.springframework.boot:spring-boot-starter-test'

// Spring AI
implementation "org.springframework.ai:spring-ai-openai:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-ollama:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-azure-openai:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-vertex-ai-embedding:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-bedrock:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-transformers:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-mistral-ai:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-minimax:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-zhipuai:${springAiVersion}"

// WebSocket
implementation 'jakarta.websocket:jakarta.websocket-api:2.1.1'
implementation 'jakarta.websocket:jakarta.websocket-client-api:2.1.1'

// DJL
implementation "ai.djl.spring:djl-spring-boot-starter-autoconfigure:${djlStarterVersion}"
implementation "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:${djlStarterVersion}"
implementation "ai.djl.huggingface:tokenizers:${djlVersion}"

// Test
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.assertj:assertj-core'
testImplementation 'org.mockito:mockito-core'
testImplementation "com.redis:testcontainers-redis:${testcontainersRedisVersion}"
testImplementation "com.karuslabs:elementary:${elementaryVersion}"
testImplementation "org.testcontainers:junit-jupiter"

// Other
implementation 'com.fasterxml.jackson.core:jackson-databind'
compileOnly "javax.enterprise:cdi-api:${cdi}"
}

// Use -parameters flag for Spring
tasks.withType(JavaCompile).configureEach {
options.compilerArgs << '-parameters'
}

// Configure annotation processing
compileTestJava {
options.annotationProcessorPath = configurations.testAnnotationProcessor
options.annotationProcessorGeneratedSourcesDirectory = file("${buildDir}/generated/sources/annotationProcessor/java/test")
}

test {
useJUnitPlatform()
maxHeapSize = "1g"

testLogging {
events "passed", "skipped", "failed"
exceptionFormat = 'full'
}
}