forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Google Vertex AI Gemini (langchain4j#402)
This PR adds basic text-only support for Gemini. Support for images and tools will be added in a follow-up. --------- Co-authored-by: kuraleta <digital.kuraleta@gmail.com>
- Loading branch information
1 parent
6ca2790
commit f185438
Showing
17 changed files
with
707 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,5 +5,6 @@ public enum FinishReason { | |
STOP, | ||
LENGTH, | ||
TOOL_EXECUTION, | ||
CONTENT_FILTER | ||
CONTENT_FILTER, | ||
OTHER | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-parent</artifactId> | ||
<version>0.25.0-SNAPSHOT</version> | ||
<relativePath>../langchain4j-parent/pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-vertex-ai-gemini</artifactId> | ||
<packaging>jar</packaging> | ||
|
||
<name>LangChain4j integration with Google Vertex AI Gemini</name> | ||
|
||
<dependencies> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.google.cloud</groupId> | ||
<artifactId>google-cloud-vertexai</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.projectlombok</groupId> | ||
<artifactId>lombok</artifactId> | ||
<scope>provided</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-engine</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-params</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.assertj</groupId> | ||
<artifactId>assertj-core</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
</dependencies> | ||
|
||
<dependencyManagement> | ||
<dependencies> | ||
<dependency> | ||
<groupId>com.google.cloud</groupId> | ||
<artifactId>libraries-bom</artifactId> | ||
<scope>import</scope> | ||
<type>pom</type> | ||
<version>26.29.0</version> | ||
</dependency> | ||
</dependencies> | ||
</dependencyManagement> | ||
|
||
<licenses> | ||
<license> | ||
<name>Apache-2.0</name> | ||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url> | ||
<distribution>repo</distribution> | ||
<comments>A business-friendly OSS license</comments> | ||
</license> | ||
</licenses> | ||
|
||
</project> |
29 changes: 29 additions & 0 deletions
29
...chain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.api.Content; | ||
import com.google.cloud.vertexai.api.Part; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.SystemMessage; | ||
|
||
import java.util.List; | ||
|
||
import static java.util.stream.Collectors.toList; | ||
|
||
class ContentsMapper { | ||
|
||
static List<Content> map(List<ChatMessage> messages) { | ||
return messages.stream() | ||
.peek(message -> { | ||
if (message instanceof SystemMessage) { | ||
throw new IllegalArgumentException("SystemMessage is currently not supported by Gemini"); | ||
} | ||
}) | ||
.map(message -> Content.newBuilder() | ||
.setRole(RoleMapper.map(message.type())) | ||
.addParts(Part.newBuilder() | ||
.setText(message.text()) | ||
.build()) | ||
.build()) | ||
.collect(toList()); | ||
} | ||
} |
19 changes: 19 additions & 0 deletions
19
...n4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/FinishReasonMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.api.Candidate; | ||
import dev.langchain4j.model.output.FinishReason; | ||
|
||
class FinishReasonMapper { | ||
|
||
static FinishReason map(Candidate.FinishReason finishReason) { | ||
switch (finishReason) { | ||
case STOP: | ||
return FinishReason.STOP; | ||
case MAX_TOKENS: | ||
return FinishReason.LENGTH; | ||
case SAFETY: | ||
return FinishReason.CONTENT_FILTER; | ||
} | ||
return FinishReason.OTHER; | ||
} | ||
} |
17 changes: 17 additions & 0 deletions
17
langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/RoleMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import dev.langchain4j.data.message.ChatMessageType; | ||
|
||
class RoleMapper { | ||
|
||
static String map(ChatMessageType type) { | ||
switch (type) { | ||
case SYSTEM: | ||
case USER: | ||
return "user"; | ||
case AI: | ||
return "model"; | ||
} | ||
throw new IllegalArgumentException(type + " is not allowed."); | ||
} | ||
} |
48 changes: 48 additions & 0 deletions
48
...-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/StreamingChatResponseBuilder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.api.Candidate; | ||
import com.google.cloud.vertexai.api.GenerateContentResponse; | ||
import com.google.cloud.vertexai.generativeai.preview.ResponseHandler; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.output.FinishReason; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.model.output.TokenUsage; | ||
|
||
import java.util.List; | ||
|
||
class StreamingChatResponseBuilder { | ||
|
||
private final StringBuffer contentBuilder = new StringBuffer(); | ||
private volatile TokenUsage tokenUsage; | ||
private volatile FinishReason finishReason; | ||
|
||
void append(GenerateContentResponse partialResponse) { | ||
if (partialResponse == null) { | ||
return; | ||
} | ||
|
||
List<Candidate> candidates = partialResponse.getCandidatesList(); | ||
if (candidates.isEmpty() || candidates.get(0) == null) { | ||
return; | ||
} | ||
|
||
contentBuilder.append(ResponseHandler.getText(partialResponse)); | ||
|
||
if (partialResponse.hasUsageMetadata()) { | ||
tokenUsage = TokenUsageMapper.map(partialResponse.getUsageMetadata()); | ||
} | ||
|
||
Candidate.FinishReason finishReason = ResponseHandler.getFinishReason(partialResponse); | ||
if (finishReason != null) { | ||
this.finishReason = FinishReasonMapper.map(finishReason); | ||
} | ||
} | ||
|
||
Response<AiMessage> build() { | ||
return Response.from( | ||
AiMessage.from(contentBuilder.toString()), | ||
tokenUsage, | ||
finishReason | ||
); | ||
} | ||
} |
15 changes: 15 additions & 0 deletions
15
...ain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/TokenUsageMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.api.GenerateContentResponse; | ||
import dev.langchain4j.model.output.TokenUsage; | ||
|
||
class TokenUsageMapper { | ||
|
||
static TokenUsage map(GenerateContentResponse.UsageMetadata usageMetadata) { | ||
return new TokenUsage( | ||
usageMetadata.getPromptTokenCount(), | ||
usageMetadata.getCandidatesTokenCount(), | ||
usageMetadata.getTotalTokenCount() | ||
); | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
...ertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package dev.langchain4j.model.vertexai; | ||
|
||
import com.google.cloud.vertexai.VertexAI; | ||
import com.google.cloud.vertexai.api.Content; | ||
import com.google.cloud.vertexai.api.GenerateContentResponse; | ||
import com.google.cloud.vertexai.api.GenerationConfig; | ||
import com.google.cloud.vertexai.generativeai.preview.GenerativeModel; | ||
import com.google.cloud.vertexai.generativeai.preview.ResponseHandler; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Builder; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; | ||
|
||
/** | ||
* Represents a Google Vertex AI Gemini language model with a chat completion interface, such as gemini-pro. | ||
* See details <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini">here</a>. | ||
*/ | ||
public class VertexAiGeminiChatModel implements ChatLanguageModel { | ||
|
||
private final GenerativeModel generativeModel; | ||
private final GenerationConfig generationConfig; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public VertexAiGeminiChatModel(String project, | ||
String location, | ||
String modelName, | ||
Float temperature, | ||
Integer maxOutputTokens, | ||
Integer topK, | ||
Float topP, | ||
Integer maxRetries) { | ||
|
||
try (VertexAI vertexAI = new VertexAI( | ||
ensureNotBlank(project, "project"), | ||
ensureNotBlank(location, "location")) | ||
) { | ||
this.generativeModel = new GenerativeModel(ensureNotBlank(modelName, "modelName"), vertexAI); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
|
||
GenerationConfig.Builder generationConfigBuilder = GenerationConfig.newBuilder(); | ||
if (temperature != null) { | ||
generationConfigBuilder.setTemperature(temperature); | ||
} | ||
if (maxOutputTokens != null) { | ||
generationConfigBuilder.setMaxOutputTokens(maxOutputTokens); | ||
} | ||
if (topK != null) { | ||
generationConfigBuilder.setTopK(topK); | ||
} | ||
if (topP != null) { | ||
generationConfigBuilder.setTopP(topP); | ||
} | ||
this.generationConfig = generationConfigBuilder.build(); | ||
|
||
this.maxRetries = getOrDefault(maxRetries, 3); | ||
} | ||
|
||
public VertexAiGeminiChatModel(GenerativeModel generativeModel, | ||
GenerationConfig generationConfig) { | ||
this.generativeModel = ensureNotNull(generativeModel, "generativeModel"); | ||
this.generationConfig = ensureNotNull(generationConfig, "generationConfig"); | ||
this.maxRetries = 3; | ||
} | ||
|
||
public VertexAiGeminiChatModel(GenerativeModel generativeModel, | ||
GenerationConfig generationConfig, | ||
Integer maxRetries) { | ||
this.generativeModel = ensureNotNull(generativeModel, "generativeModel"); | ||
this.generationConfig = ensureNotNull(generationConfig, "generationConfig"); | ||
this.maxRetries = getOrDefault(maxRetries, 3); | ||
} | ||
|
||
@Override | ||
public Response<AiMessage> generate(List<ChatMessage> messages) { | ||
|
||
List<Content> contents = ContentsMapper.map(messages); | ||
|
||
GenerateContentResponse response = withRetry(() -> generativeModel.generateContent(contents, generationConfig), maxRetries); | ||
|
||
return Response.from( | ||
AiMessage.from(ResponseHandler.getText(response)), | ||
TokenUsageMapper.map(response.getUsageMetadata()), | ||
FinishReasonMapper.map(ResponseHandler.getFinishReason(response)) | ||
); | ||
} | ||
} |
Oops, something went wrong.