Skip to content

Add retry support to VertexAI embedding and chat models #1437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/spring-ai-vertex-ai-embedding/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;

/**
* VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex
* VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex
* AI embedding service. It provides methods to access the project ID, location,
* publisher, and PredictionServiceSettings.
*
* @author Christian Tzolov
* @since 1.0.0
*/
public class VertexAiEmbeddigConnectionDetails {
public class VertexAiEmbeddingConnectionDetails {

private static final String DEFAULT_LOCATION = "us-central1";

Expand Down Expand Up @@ -55,7 +58,7 @@ public class VertexAiEmbeddigConnectionDetails {

private final String publisher;

public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) {
public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) {
this.projectId = projectId;
this.location = location;
this.publisher = publisher;
Expand Down Expand Up @@ -119,7 +122,7 @@ public Builder withPublisher(String publisher) {
return this;
}

public VertexAiEmbeddigConnectionDetails build() {
public VertexAiEmbeddingConnectionDetails build() {
if (!StringUtils.hasText(this.endpoint)) {
if (!StringUtils.hasText(this.location)) {
this.endpoint = DEFAULT_ENDPOINT;
Expand All @@ -134,7 +137,7 @@ public VertexAiEmbeddigConnectionDetails build() {
this.publisher = DEFAULT_PUBLISHER;
}

return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
return new VertexAiEmbeddingConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
Expand Down Expand Up @@ -76,9 +76,9 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));

private final VertexAiEmbeddigConnectionDetails connectionDetails;
private final VertexAiEmbeddingConnectionDetails connectionDetails;

public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) {

Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -53,16 +56,22 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {

public final VertexAiTextEmbeddingOptions defaultOptions;

private final VertexAiEmbeddigConnectionDetails connectionDetails;
private final VertexAiEmbeddingConnectionDetails connectionDetails;

private final RetryTemplate retryTemplate;

public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");

Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();

this.connectionDetails = connectionDetails;
this.retryTemplate = retryTemplate;
}

@Override
Expand All @@ -73,46 +82,23 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
return retryTemplate.execute(context -> {
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;

VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;

if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}

try (PredictionServiceClient client = PredictionServiceClient
.create(this.connectionDetails.getPredictionServiceSettings())) {

EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());

PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder()
.setEndpoint(endpointName.toString());

TextParametersBuilder parametersBuilder = TextParametersBuilder.of();

if (finalOptions.getAutoTruncate() != null) {
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
}

if (finalOptions.getDimensions() != null) {
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}

predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
PredictionServiceClient client = createPredictionServiceClient();

for (int i = 0; i < request.getInstructions().size(); i++) {
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());

TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
.withTaskType(finalOptions.getTaskType().name());
if (StringUtils.hasText(finalOptions.getTitle())) {
instanceBuilder.withTitle(finalOptions.getTitle());
}
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
}
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
finalOptions);

PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);

int index = 0;
int totalTokenCount = 0;
Expand All @@ -131,12 +117,53 @@ public EmbeddingResponse call(EmbeddingRequest request) {
}
return new EmbeddingResponse(embeddingList,
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
});
}

protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
VertexAiTextEmbeddingOptions finalOptions) {
PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString());

TextParametersBuilder parametersBuilder = TextParametersBuilder.of();

if (finalOptions.getAutoTruncate() != null) {
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
}
catch (Exception e) {

if (finalOptions.getDimensions() != null) {
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
}

predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));

for (int i = 0; i < request.getInstructions().size(); i++) {

TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
.withTaskType(finalOptions.getTaskType().name());
if (StringUtils.hasText(finalOptions.getTitle())) {
instanceBuilder.withTitle(finalOptions.getTitle());
}
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
}
return predictRequestBuilder;
}

// for testing
PredictionServiceClient createPredictionServiceClient() {
try {
return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
}
catch (IOException e) {
throw new RuntimeException(e);
}
}

// for testing
PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
return embeddingResponse;
}

private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.setModel(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -213,16 +213,16 @@ void textImageAndVideoEmbedding() {
static class Config {

@Bean
public VertexAiEmbeddigConnectionDetails connectionDetails() {
return VertexAiEmbeddigConnectionDetails.builder()
public VertexAiEmbeddingConnectionDetails connectionDetails() {
return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}

@Bean
public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel(
VertexAiEmbeddigConnectionDetails connectionDetails) {
VertexAiEmbeddingConnectionDetails connectionDetails) {

VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder()
.withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.springframework.ai.vertexai.embedding.text;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.retry.support.RetryTemplate;

import java.io.IOException;

public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel {

private PredictionServiceClient mockPredictionServiceClient;

private PredictRequest.Builder mockPredictRequestBuilder;

public TestVertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
super(connectionDetails, defaultEmbeddingOptions, retryTemplate);
}

public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) {
this.mockPredictionServiceClient = mockPredictionServiceClient;
}

@Override
PredictionServiceClient createPredictionServiceClient() {
if (mockPredictionServiceClient != null) {
return mockPredictionServiceClient;
}
return super.createPredictionServiceClient();
}

@Override
PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
if (mockPredictionServiceClient != null) {
return mockPredictionServiceClient.predict(predictRequestBuilder.build());
}
return super.getPredictResponse(client, predictRequestBuilder);
}

public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) {
this.mockPredictRequestBuilder = mockPredictRequestBuilder;
}

@Override
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
VertexAiTextEmbeddingOptions finalOptions) {
if (mockPredictRequestBuilder != null) {
return mockPredictRequestBuilder;
}
return super.getPredictRequestBuilder(request, endpointName, finalOptions);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -67,15 +67,15 @@ void defaultEmbedding(String modelName) {
static class Config {

@Bean
public VertexAiEmbeddigConnectionDetails connectionDetails() {
return VertexAiEmbeddigConnectionDetails.builder()
public VertexAiEmbeddingConnectionDetails connectionDetails() {
return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}

@Bean
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) {
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) {

VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
Expand Down
Loading