Skip to content

Commit

Permalink
Integration with Google Vertex AI (langchain4j#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuraleta authored Aug 28, 2023
1 parent 20753a9 commit 88b5677
Show file tree
Hide file tree
Showing 15 changed files with 786 additions and 60 deletions.
5 changes: 5 additions & 0 deletions langchain4j-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
<artifactId>compiler</artifactId>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package dev.langchain4j.internal;

import dev.ai4j.openai4j.OpenAiHttpException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -10,16 +9,11 @@

public class RetryUtils {

private static final int HTTP_CODE_401_UNAUTHORIZED = 401;
private static final int HTTP_CODE_429_TOO_MANY_REQUESTS = 429;

private static final Logger log = LoggerFactory.getLogger(RetryUtils.class);

/**
* This method attempts to execute a given action up to a specified number of times.
* This method attempts to execute a given action up to a specified number of times with a 1-second delay.
* If the action fails on all attempts, it throws a RuntimeException.
* Retry will not happen for 401 (Unauthorized).
* Retry will happen after 1-second delay for 429 (Too many requests).
*
* @param action The action to be executed.
* @param maxAttempts The maximum number of attempts to execute the action.
Expand All @@ -30,29 +24,17 @@ public static <T> T withRetry(Callable<T> action, int maxAttempts) {
for (int attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return action.call();
} catch (OpenAiHttpException e) {
} catch (Exception e) {
if (attempt == maxAttempts) {
throw new RuntimeException(e);
}

if (e.code() == HTTP_CODE_401_UNAUTHORIZED) {
throw new RuntimeException(e); // makes no sense to retry
}

log.warn(format("Exception was thrown on attempt %s of %s", attempt, maxAttempts), e);

if (e.code() == HTTP_CODE_429_TOO_MANY_REQUESTS) {
try {
// TODO make configurable or read from Retry-After
Thread.sleep(1000); // makes sense to retry after a bit of waiting
} catch (InterruptedException ignored) {
}
try {
Thread.sleep(1000); // TODO make configurable
} catch (InterruptedException ignored) {
}
} catch (Exception e) {
if (attempt == maxAttempts) {
throw new RuntimeException(e);
}
log.warn(format("Exception was thrown on attempt %s of %s", attempt, maxAttempts), e);
}
}
throw new RuntimeException("Failed after " + maxAttempts + " attempts");
Expand Down
61 changes: 61 additions & 0 deletions langchain4j-vertex-ai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the langchain4j-vertex-ai module (LangChain4j integration with Vertex AI). -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<!-- Inherits from the shared parent POM, resolved from the sibling langchain4j-parent directory. -->
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.21.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-vertex-ai</artifactId>
<packaging>jar</packaging>

<name>LangChain4j integration with Vertex AI</name>

<dependencies>

<!-- Core LangChain4j module, pinned to the same version as this module. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<!-- Google Cloud AI Platform client library; the version is pinned explicitly in this module. -->
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-aiplatform</artifactId>
<version>3.24.0</version>
</dependency>

<!-- Test-scoped dependencies. No versions are declared here; -->
<!-- NOTE(review): presumably managed by langchain4j-parent - confirm. -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<!-- License metadata published with the artifact. -->
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dev.langchain4j.model.vertexai;

import java.util.List;

/**
 * Immutable holder for one Vertex AI chat request instance: a context string plus a list of messages.
 * <p>
 * VertexAiChatModel builds it from the system messages (context) and the user/AI messages (messages)
 * and passes it to Json.toJson when assembling the prediction request.
 * NOTE(review): the field names presumably become the JSON keys of the request - confirm before renaming.
 */
class VertexAiChatInstance {

// Neither field is read inside this class.
// NOTE(review): presumably accessed reflectively by the JSON serializer - confirm.
private final String context;
private final List<Message> messages;

/**
 * @param context  the context text of the conversation
 * @param messages the conversation messages; the list is stored as-is, without copying
 */
VertexAiChatInstance(String context, List<Message> messages) {
this.context = context;
this.messages = messages;
}

/**
 * A single chat message: who wrote it and what it says.
 */
static class Message {

// Not read inside this class; see the note on the outer class fields.
private final String author;
private final String content;

/**
 * @param author  identifier of the message author
 * @param content the message text
 */
Message(String author, String content) {
this.author = author;
this.content = content;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;

import java.io.IOException;
import java.util.List;

import static com.google.protobuf.Value.newBuilder;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.data.message.ChatMessageType.SYSTEM;
import static dev.langchain4j.data.message.ChatMessageType.USER;
import static dev.langchain4j.internal.Json.toJson;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

/**
 * Represents a connection to the Vertex AI LLM with a chat completion interface, such as chat-bison.
 * See details <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts">here</a>.
 * <p>
 * System messages are joined (newline-separated) into the context of the request instance, while user
 * and AI messages are sent as the conversation messages. Tool specifications are not supported.
 */
public class VertexAiChatModel implements ChatLanguageModel {

    private final PredictionServiceSettings settings;
    private final EndpointName endpointName;
    private final VertexAiParameters vertexAiParameters;
    private final Integer maxRetries;

    /**
     * Creates a chat model bound to a specific Vertex AI publisher model.
     *
     * @param endpoint        service endpoint to connect to; must not be blank
     * @param project         project identifier; must not be blank
     * @param location        location identifier; must not be blank
     * @param publisher       model publisher; must not be blank
     * @param modelName       name of the model; must not be blank
     * @param temperature     sampling temperature forwarded in the request parameters; may be null
     * @param maxOutputTokens output token limit forwarded in the request parameters; may be null
     * @param topK            top-K value forwarded in the request parameters; may be null
     * @param topP            top-P value forwarded in the request parameters; may be null
     * @param maxRetries      maximum number of attempts handed to the retry helper; defaults to 3 when null
     * @throws RuntimeException if the prediction service settings cannot be built
     */
    public VertexAiChatModel(String endpoint,
                             String project,
                             String location,
                             String publisher,
                             String modelName,
                             Double temperature,
                             Integer maxOutputTokens,
                             Integer topK,
                             Double topP,
                             Integer maxRetries) {
        try {
            this.settings = PredictionServiceSettings.newBuilder()
                    .setEndpoint(ensureNotBlank(endpoint, "endpoint"))
                    .build();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName(
                ensureNotBlank(project, "project"),
                ensureNotBlank(location, "location"),
                ensureNotBlank(publisher, "publisher"),
                ensureNotBlank(modelName, "modelName")
        );
        this.vertexAiParameters = new VertexAiParameters(temperature, maxOutputTokens, topK, topP);
        this.maxRetries = maxRetries == null ? 3 : maxRetries;
    }

    /**
     * Sends the conversation to the model and returns the first candidate of the first prediction.
     * <p>
     * A new {@link PredictionServiceClient} is created for every call and closed when the call ends.
     *
     * @param messages the conversation so far
     * @return the model's reply
     * @throws RuntimeException      if the client cannot be created, the request cannot be encoded,
     *                               or the prediction still fails after all attempts
     * @throws IllegalStateException if the response carries no prediction, candidate or content
     */
    @Override
    public AiMessage sendMessages(List<ChatMessage> messages) {
        try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {

            VertexAiChatInstance vertexAiChatInstance = new VertexAiChatInstance(
                    toContext(messages),
                    toVertexMessages(messages)
            );

            // The request payloads are built as JSON first and then merged into protobuf Values.
            Value.Builder instanceBuilder = newBuilder();
            JsonFormat.parser().merge(toJson(vertexAiChatInstance), instanceBuilder);
            List<Value> instances = singletonList(instanceBuilder.build());

            Value.Builder parametersBuilder = newBuilder();
            JsonFormat.parser().merge(toJson(vertexAiParameters), parametersBuilder);
            Value parameters = parametersBuilder.build();

            PredictResponse response = withRetry(() -> client.predict(endpointName, instances, parameters), maxRetries);

            return aiMessage(extractContent(response));

        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads {@code predictions[0].candidates[0].content} from the response.
     * <p>
     * Each level is checked explicitly so that an unexpected response shape fails with a descriptive
     * message instead of a bare IndexOutOfBoundsException or NullPointerException.
     */
    private static String extractContent(PredictResponse predictResponse) {
        if (predictResponse.getPredictionsCount() == 0) {
            throw new IllegalStateException("Vertex AI response contains no predictions");
        }

        Value candidates = predictResponse.getPredictions(0)
                .getStructValue()
                .getFieldsMap()
                .get("candidates");
        // getListValue() yields an empty list when the value is not a list, so one count check covers both cases.
        if (candidates == null || candidates.getListValue().getValuesCount() == 0) {
            throw new IllegalStateException("Vertex AI response contains no candidates");
        }

        Value content = candidates.getListValue()
                .getValues(0)
                .getStructValue()
                .getFieldsMap()
                .get("content");
        if (content == null) {
            throw new IllegalStateException("Vertex AI response candidate contains no content");
        }

        return content.getStringValue();
    }

    /**
     * Keeps only USER and AI messages and converts them to request messages,
     * using the message type name as the author.
     */
    private static List<VertexAiChatInstance.Message> toVertexMessages(List<ChatMessage> messages) {
        return messages.stream()
                .filter(chatMessage -> chatMessage.type() == USER || chatMessage.type() == AI)
                .map(chatMessage -> new VertexAiChatInstance.Message(chatMessage.type().name(), chatMessage.text()))
                .collect(toList());
    }

    /**
     * Joins the text of all SYSTEM messages with newlines; yields an empty string when there are none.
     */
    private static String toContext(List<ChatMessage> messages) {
        return messages.stream()
                .filter(chatMessage -> chatMessage.type() == SYSTEM)
                .map(ChatMessage::text)
                .collect(joining("\n"));
    }

    /**
     * Not supported.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models");
    }

    /**
     * Not supported.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public AiMessage sendMessages(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models");
    }

    /**
     * @return a new builder with {@code maxOutputTokens} preset to 200 and all other values unset
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link VertexAiChatModel}. Values are validated by the model constructor
     * when {@link #build()} is called, not by the setters.
     */
    public static class Builder {

        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;

        private Double temperature;
        private Integer maxOutputTokens = 200;
        private Integer topK;
        private Double topP;

        private Integer maxRetries;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder project(String project) {
            this.project = project;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder maxOutputTokens(Integer maxOutputTokens) {
            this.maxOutputTokens = maxOutputTokens;
            return this;
        }

        public Builder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public Builder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * @return a model configured with the values collected so far
         */
        public VertexAiChatModel build() {
            return new VertexAiChatModel(
                    endpoint,
                    project,
                    location,
                    publisher,
                    modelName,
                    temperature,
                    maxOutputTokens,
                    topK,
                    topP,
                    maxRetries);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.vertexai;

/**
 * Immutable holder for a single content string.
 * NOTE(review): presumably serialized to JSON as one embedding request instance, the way
 * VertexAiChatInstance is for chat - its usage is not visible here; confirm before renaming the field.
 */
class VertexAiEmbeddingInstance {

// Not read inside this class.
// NOTE(review): presumably accessed reflectively by a serializer - confirm.
private final String content;

/**
 * @param content the text held by this instance
 */
VertexAiEmbeddingInstance(String content) {
this.content = content;
}
}
Loading

0 comments on commit 88b5677

Please sign in to comment.