forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integration with Google Vertex AI (langchain4j#135)
- Loading branch information
Showing
15 changed files
with
786 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-parent</artifactId> | ||
<version>0.21.0</version> | ||
<relativePath>../langchain4j-parent/pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-vertex-ai</artifactId> | ||
<packaging>jar</packaging> | ||
|
||
<name>LangChain4j integration with Vertex AI</name> | ||
|
||
<dependencies> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
<version>${project.version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.google.cloud</groupId> | ||
<artifactId>google-cloud-aiplatform</artifactId> | ||
<version>3.24.0</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-engine</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-params</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.assertj</groupId> | ||
<artifactId>assertj-core</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
</dependencies> | ||
|
||
<licenses> | ||
<license> | ||
<name>Apache-2.0</name> | ||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url> | ||
<distribution>repo</distribution> | ||
<comments>A business-friendly OSS license</comments> | ||
</license> | ||
</licenses> | ||
|
||
</project> |
25 changes: 25 additions & 0 deletions
25
langchain4j-vertex-ai/src/main/java/dev/langchain4j/model/vertexai/VertexAiChatInstance.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import java.util.List; | ||
|
||
/**
 * Immutable holder for a single chat "instance": a context string plus the conversation messages.
 * <p>
 * {@code VertexAiChatModel} passes objects of this class to {@code Json.toJson}; the field names
 * are presumably used as the serialized keys, so they should be kept stable (TODO confirm against
 * the JSON helper).
 */
class VertexAiChatInstance {

    /**
     * One conversation turn: who wrote it and what was written.
     */
    static class Message {

        private final String author;
        private final String content;

        /**
         * @param author  identifier of the message author
         * @param content text of the message
         */
        Message(String author, String content) {
            this.author = author;
            this.content = content;
        }
    }

    private final String context;
    private final List<Message> messages;

    /**
     * @param context  context text accompanying the conversation
     * @param messages conversation turns; the list is stored as given, without copying
     */
    VertexAiChatInstance(String context, List<Message> messages) {
        this.context = context;
        this.messages = messages;
    }
}
213 changes: 213 additions & 0 deletions
213
langchain4j-vertex-ai/src/main/java/dev/langchain4j/model/vertexai/VertexAiChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.aiplatform.v1.EndpointName; | ||
import com.google.cloud.aiplatform.v1.PredictResponse; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceClient; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceSettings; | ||
import com.google.protobuf.Value; | ||
import com.google.protobuf.util.JsonFormat; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import static com.google.protobuf.Value.newBuilder; | ||
import static dev.langchain4j.data.message.AiMessage.aiMessage; | ||
import static dev.langchain4j.data.message.ChatMessageType.AI; | ||
import static dev.langchain4j.data.message.ChatMessageType.SYSTEM; | ||
import static dev.langchain4j.data.message.ChatMessageType.USER; | ||
import static dev.langchain4j.internal.Json.toJson; | ||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.util.Collections.singletonList; | ||
import static java.util.stream.Collectors.joining; | ||
import static java.util.stream.Collectors.toList; | ||
|
||
/** | ||
* Represents a connection to the Vertex AI LLM with a chat completion interface, such as chat-bison. | ||
* See details <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts">here</a>. | ||
*/ | ||
public class VertexAiChatModel implements ChatLanguageModel { | ||
|
||
private final PredictionServiceSettings settings; | ||
private final EndpointName endpointName; | ||
private final VertexAiParameters vertexAiParameters; | ||
private final Integer maxRetries; | ||
|
||
public VertexAiChatModel(String endpoint, | ||
String project, | ||
String location, | ||
String publisher, | ||
String modelName, | ||
Double temperature, | ||
Integer maxOutputTokens, | ||
Integer topK, | ||
Double topP, | ||
Integer maxRetries) { | ||
try { | ||
this.settings = PredictionServiceSettings.newBuilder() | ||
.setEndpoint(ensureNotBlank(endpoint, "endpoint")) | ||
.build(); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
this.endpointName = EndpointName.ofProjectLocationPublisherModelName( | ||
ensureNotBlank(project, "project"), | ||
ensureNotBlank(location, "location"), | ||
ensureNotBlank(publisher, "publisher"), | ||
ensureNotBlank(modelName, "modelName") | ||
); | ||
this.vertexAiParameters = new VertexAiParameters(temperature, maxOutputTokens, topK, topP); | ||
this.maxRetries = maxRetries == null ? 3 : maxRetries; | ||
} | ||
|
||
@Override | ||
public AiMessage sendMessages(List<ChatMessage> messages) { | ||
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) { | ||
|
||
VertexAiChatInstance vertexAiChatInstance = new VertexAiChatInstance( | ||
toContext(messages), | ||
toVertexMessages(messages) | ||
); | ||
|
||
Value.Builder instanceBuilder = newBuilder(); | ||
JsonFormat.parser().merge(toJson(vertexAiChatInstance), instanceBuilder); | ||
List<Value> instances = singletonList(instanceBuilder.build()); | ||
|
||
Value.Builder parametersBuilder = newBuilder(); | ||
JsonFormat.parser().merge(toJson(vertexAiParameters), parametersBuilder); | ||
Value parameters = parametersBuilder.build(); | ||
|
||
PredictResponse response = withRetry(() -> client.predict(endpointName, instances, parameters), maxRetries); | ||
|
||
return aiMessage(extractContent(response)); | ||
|
||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private static String extractContent(PredictResponse predictResponse) { | ||
return predictResponse.getPredictions(0) | ||
.getStructValue() | ||
.getFieldsMap() | ||
.get("candidates") | ||
.getListValue() | ||
.getValues(0) | ||
.getStructValue() | ||
.getFieldsMap() | ||
.get("content") | ||
.getStringValue(); | ||
} | ||
|
||
private static List<VertexAiChatInstance.Message> toVertexMessages(List<ChatMessage> messages) { | ||
return messages.stream() | ||
.filter(chatMessage -> chatMessage.type() == USER || chatMessage.type() == AI) | ||
.map(chatMessage -> new VertexAiChatInstance.Message(chatMessage.type().name(), chatMessage.text())) | ||
.collect(toList()); | ||
} | ||
|
||
private static String toContext(List<ChatMessage> messages) { | ||
return messages.stream() | ||
.filter(chatMessage -> chatMessage.type() == SYSTEM) | ||
.map(ChatMessage::text) | ||
.collect(joining("\n")); | ||
} | ||
|
||
@Override | ||
public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) { | ||
throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models"); | ||
} | ||
|
||
@Override | ||
public AiMessage sendMessages(List<ChatMessage> messages, ToolSpecification toolSpecification) { | ||
throw new IllegalArgumentException("Tools are currently not supported for Vertex AI models"); | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
|
||
private String endpoint; | ||
private String project; | ||
private String location; | ||
private String publisher; | ||
private String modelName; | ||
|
||
private Double temperature; | ||
private Integer maxOutputTokens = 200; | ||
private Integer topK; | ||
private Double topP; | ||
|
||
private Integer maxRetries; | ||
|
||
public Builder endpoint(String endpoint) { | ||
this.endpoint = endpoint; | ||
return this; | ||
} | ||
|
||
public Builder project(String project) { | ||
this.project = project; | ||
return this; | ||
} | ||
|
||
public Builder location(String location) { | ||
this.location = location; | ||
return this; | ||
} | ||
|
||
public Builder publisher(String publisher) { | ||
this.publisher = publisher; | ||
return this; | ||
} | ||
|
||
public Builder modelName(String modelName) { | ||
this.modelName = modelName; | ||
return this; | ||
} | ||
|
||
public Builder temperature(Double temperature) { | ||
this.temperature = temperature; | ||
return this; | ||
} | ||
|
||
public Builder maxOutputTokens(Integer maxOutputTokens) { | ||
this.maxOutputTokens = maxOutputTokens; | ||
return this; | ||
} | ||
|
||
public Builder topK(Integer topK) { | ||
this.topK = topK; | ||
return this; | ||
} | ||
|
||
public Builder topP(Double topP) { | ||
this.topP = topP; | ||
return this; | ||
} | ||
|
||
public Builder maxRetries(Integer maxRetries) { | ||
this.maxRetries = maxRetries; | ||
return this; | ||
} | ||
|
||
public VertexAiChatModel build() { | ||
return new VertexAiChatModel( | ||
endpoint, | ||
project, | ||
location, | ||
publisher, | ||
modelName, | ||
temperature, | ||
maxOutputTokens, | ||
topK, | ||
topP, | ||
maxRetries); | ||
} | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
...n4j-vertex-ai/src/main/java/dev/langchain4j/model/vertexai/VertexAiEmbeddingInstance.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
/**
 * Immutable holder for a single piece of text content.
 * <p>
 * NOTE(review): presumably serialized as one embedding request instance, with the field name used
 * as the JSON key; the caller is not visible here, so confirm before renaming the field.
 */
class VertexAiEmbeddingInstance {

    private final String content;

    /**
     * @param content the text held by this instance
     */
    VertexAiEmbeddingInstance(String content) {
        this.content = content;
    }
}
Oops, something went wrong.