Skip to content

Commit

Permalink
Support Google Vertex AI Gemini (langchain4j#402)
Browse files Browse the repository at this point in the history
This PR adds basic support for Gemini (text models).
Support for images and tools will be added in a follow-up PR.

---------

Co-authored-by: kuraleta <digital.kuraleta@gmail.com>
  • Loading branch information
langchain4j and kuraleta authored Dec 22, 2023
1 parent 6ca2790 commit f185438
Show file tree
Hide file tree
Showing 17 changed files with 707 additions and 6 deletions.
6 changes: 6 additions & 0 deletions langchain4j-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-vertex-ai-gemini</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ public enum FinishReason {
STOP,
LENGTH,
TOOL_EXECUTION,
CONTENT_FILTER
CONTENT_FILTER,
OTHER
}
78 changes: 78 additions & 0 deletions langchain4j-vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.25.0-SNAPSHOT</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-vertex-ai-gemini</artifactId>
<packaging>jar</packaging>

<name>LangChain4j integration with Google Vertex AI Gemini</name>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<!-- Google Cloud Vertex AI SDK; its version is managed by the libraries-bom import below -->
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-vertexai</artifactId>
</dependency>

<!-- Compile-time only (@Builder etc.); not needed at runtime, hence "provided" -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<!-- Import Google Cloud's libraries BOM so all Google Cloud artifacts use mutually compatible versions -->
<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>libraries-bom</artifactId>
<scope>import</scope>
<type>pom</type>
<version>26.29.0</version>
</dependency>
</dependencies>
</dependencyManagement>

<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.Part;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;

import java.util.List;

import static java.util.stream.Collectors.toList;

class ContentsMapper {

    /**
     * Maps langchain4j {@link ChatMessage}s to Vertex AI Gemini {@link Content}s.
     * <p>
     * Each message becomes a single {@link Content} with one text {@link Part} and
     * a Gemini role derived from the message type via {@link RoleMapper}.
     *
     * @param messages the conversation history to convert
     * @return the equivalent list of Gemini {@code Content} objects
     * @throws IllegalArgumentException if any message is a {@link SystemMessage},
     *                                  which Gemini does not currently support
     */
    static List<Content> map(List<ChatMessage> messages) {
        return messages.stream()
                .map(message -> {
                    // Validate inside map(): Stream.peek is documented for debugging only
                    // and its side effects may be elided by the pipeline.
                    if (message instanceof SystemMessage) {
                        throw new IllegalArgumentException("SystemMessage is currently not supported by Gemini");
                    }
                    return Content.newBuilder()
                            .setRole(RoleMapper.map(message.type()))
                            .addParts(Part.newBuilder()
                                    .setText(message.text())
                                    .build())
                            .build();
                })
                .collect(toList());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.Candidate;
import dev.langchain4j.model.output.FinishReason;

class FinishReasonMapper {

    /**
     * Translates a Gemini candidate finish reason into the langchain4j equivalent.
     * Any reason without a direct counterpart maps to {@link FinishReason#OTHER}.
     */
    static FinishReason map(Candidate.FinishReason finishReason) {
        switch (finishReason) {
            case STOP:
                return FinishReason.STOP;
            case MAX_TOKENS:
                return FinishReason.LENGTH;
            case SAFETY:
                return FinishReason.CONTENT_FILTER;
            default:
                return FinishReason.OTHER;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.langchain4j.model.vertexai;

import dev.langchain4j.data.message.ChatMessageType;

class RoleMapper {

    /**
     * Maps a langchain4j message type to the role string Gemini expects.
     * SYSTEM is folded into "user" since Gemini has no dedicated system role.
     *
     * @throws IllegalArgumentException for message types Gemini cannot represent
     */
    static String map(ChatMessageType type) {
        switch (type) {
            case SYSTEM:
            case USER:
                return "user";
            case AI:
                return "model";
            default:
                throw new IllegalArgumentException(type + " is not allowed.");
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.Candidate;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.generativeai.preview.ResponseHandler;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.List;

/**
 * Accumulates partial streaming responses from Gemini into a single
 * {@code Response<AiMessage>}.
 * <p>
 * Uses a synchronized {@link StringBuffer} and {@code volatile} fields because
 * {@code append} may run on an SDK callback thread while {@code build} is called
 * elsewhere — NOTE(review): confirm against the SDK's actual threading model.
 */
class StreamingChatResponseBuilder {

    private final StringBuffer contentBuilder = new StringBuffer();
    private volatile TokenUsage tokenUsage;
    private volatile FinishReason finishReason;

    /**
     * Folds one partial response into the accumulator: appends its text and
     * records token usage / finish reason whenever the chunk carries them.
     * Null or candidate-less chunks are ignored.
     */
    void append(GenerateContentResponse partialResponse) {
        if (partialResponse == null) {
            return;
        }

        List<Candidate> candidates = partialResponse.getCandidatesList();
        if (candidates.isEmpty() || candidates.get(0) == null) {
            return;
        }

        contentBuilder.append(ResponseHandler.getText(partialResponse));

        if (partialResponse.hasUsageMetadata()) {
            tokenUsage = TokenUsageMapper.map(partialResponse.getUsageMetadata());
        }

        // Renamed local (was "finishReason") so it no longer shadows the field above.
        Candidate.FinishReason partialFinishReason = ResponseHandler.getFinishReason(partialResponse);
        if (partialFinishReason != null) {
            this.finishReason = FinishReasonMapper.map(partialFinishReason);
        }
    }

    /** Builds the final response from everything accumulated so far. */
    Response<AiMessage> build() {
        return Response.from(
                AiMessage.from(contentBuilder.toString()),
                tokenUsage,
                finishReason
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.GenerateContentResponse;
import dev.langchain4j.model.output.TokenUsage;

class TokenUsageMapper {

    /** Converts Gemini usage metadata into a langchain4j {@link TokenUsage}. */
    static TokenUsage map(GenerateContentResponse.UsageMetadata usageMetadata) {
        int inputTokenCount = usageMetadata.getPromptTokenCount();
        int outputTokenCount = usageMetadata.getCandidatesTokenCount();
        int totalTokenCount = usageMetadata.getTotalTokenCount();
        return new TokenUsage(inputTokenCount, outputTokenCount, totalTokenCount);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.generativeai.preview.GenerativeModel;
import com.google.cloud.vertexai.generativeai.preview.ResponseHandler;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.io.IOException;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* Represents a Google Vertex AI Gemini language model with a chat completion interface, such as gemini-pro.
* See details <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini">here</a>.
*/
/**
 * Represents a Google Vertex AI Gemini language model with a chat completion interface, such as gemini-pro.
 * See details <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini">here</a>.
 */
public class VertexAiGeminiChatModel implements ChatLanguageModel {

    private final GenerativeModel generativeModel;
    private final GenerationConfig generationConfig;
    private final Integer maxRetries;

    /**
     * Creates the model by constructing a {@link VertexAI} client for the given
     * project/location and wrapping it in a {@link GenerativeModel}.
     *
     * @param project         GCP project id (required)
     * @param location        GCP region, e.g. "us-central1" (required)
     * @param modelName       Gemini model name, e.g. "gemini-pro" (required)
     * @param temperature     optional sampling temperature
     * @param maxOutputTokens optional output token cap
     * @param topK            optional top-K sampling parameter
     * @param topP            optional nucleus sampling parameter
     * @param maxRetries      optional retry count (defaults to 3)
     */
    @Builder
    public VertexAiGeminiChatModel(String project,
                                   String location,
                                   String modelName,
                                   Float temperature,
                                   Integer maxOutputTokens,
                                   Integer topK,
                                   Float topP,
                                   Integer maxRetries) {

        try {
            // Deliberately NOT try-with-resources: closing the VertexAI client here
            // would shut down its channels, and every later generateContent() call
            // made through generativeModel would fail against a closed client.
            // The client therefore lives for the lifetime of this model instance.
            VertexAI vertexAI = new VertexAI(
                    ensureNotBlank(project, "project"),
                    ensureNotBlank(location, "location"));
            this.generativeModel = new GenerativeModel(ensureNotBlank(modelName, "modelName"), vertexAI);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // Only set parameters the caller actually provided; leave SDK defaults otherwise.
        GenerationConfig.Builder generationConfigBuilder = GenerationConfig.newBuilder();
        if (temperature != null) {
            generationConfigBuilder.setTemperature(temperature);
        }
        if (maxOutputTokens != null) {
            generationConfigBuilder.setMaxOutputTokens(maxOutputTokens);
        }
        if (topK != null) {
            generationConfigBuilder.setTopK(topK);
        }
        if (topP != null) {
            generationConfigBuilder.setTopP(topP);
        }
        this.generationConfig = generationConfigBuilder.build();

        this.maxRetries = getOrDefault(maxRetries, 3);
    }

    /** Creates the model from a pre-built {@link GenerativeModel}, with the default of 3 retries. */
    public VertexAiGeminiChatModel(GenerativeModel generativeModel,
                                   GenerationConfig generationConfig) {
        this(generativeModel, generationConfig, 3);
    }

    /** Creates the model from a pre-built {@link GenerativeModel} with an explicit retry count. */
    public VertexAiGeminiChatModel(GenerativeModel generativeModel,
                                   GenerationConfig generationConfig,
                                   Integer maxRetries) {
        this.generativeModel = ensureNotNull(generativeModel, "generativeModel");
        this.generationConfig = ensureNotNull(generationConfig, "generationConfig");
        this.maxRetries = getOrDefault(maxRetries, 3);
    }

    /**
     * Generates a completion for the given conversation, retrying transient
     * failures up to {@code maxRetries} times.
     *
     * @throws IllegalArgumentException if any message is a SystemMessage (unsupported by Gemini)
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {

        List<Content> contents = ContentsMapper.map(messages);

        GenerateContentResponse response = withRetry(() -> generativeModel.generateContent(contents, generationConfig), maxRetries);

        return Response.from(
                AiMessage.from(ResponseHandler.getText(response)),
                TokenUsageMapper.map(response.getUsageMetadata()),
                FinishReasonMapper.map(ResponseHandler.getFinishReason(response))
        );
    }
}
Loading

0 comments on commit f185438

Please sign in to comment.