Skip to content

Commit

Permalink
Gemini updates (langchain4j#1278)
Browse files Browse the repository at this point in the history
Gemini updates:
* update to latest Java SDK version
* langchain4j#1269
* langchain4j#1270
* langchain4j#1208
* langchain4j#1399 
* langchain4j#1397 
* langchain4j#1182
* langchain4j#828
* fixes parallel function calling which wasn't working properly in the
previous release
* refactored a bit the `generate()` method to have a single entry point
and less duplication

Vertex AI embedding model:
* add new task types (question answering and fact verification)

Imagen image model:
* support more configuration parameters
* langchain4j#1367
  • Loading branch information
glaforge authored Jul 3, 2024
1 parent abc957a commit c154203
Show file tree
Hide file tree
Showing 21 changed files with 1,741 additions and 106 deletions.
8 changes: 8 additions & 0 deletions docs/docs/tutorials/5-ai-services.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,14 @@ AzureOpenAiChatModel.builder()
.build();
```

- For Vertex AI Gemini:
```java
VertexAiGeminiChatModel.builder()
...
.responseMimeType("application/json")
.build();
```

- For Mistral AI:
```java
MistralAiChatModel.builder()
Expand Down
21 changes: 20 additions & 1 deletion langchain4j-vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-vertexai</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>proto-google-cloud-vertexai-v1</artifactId>
<version>1.6.0</version>
</dependency>

<dependency>
Expand Down Expand Up @@ -94,7 +105,7 @@
<artifactId>libraries-bom</artifactId>
<scope>import</scope>
<type>pom</type>
<version>26.39.0</version>
<version>26.42.0</version>
</dependency>
</dependencies>
</dependencyManagement>
Expand All @@ -108,6 +119,14 @@
<skipTests>${skipVertexAiGeminiITs}</skipTests>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.13.0</version>
<configuration>
<parameters>true</parameters>
</configuration>
</plugin>
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,100 @@

import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.Part;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

class ContentsMapper {
static class InstructionAndContent {
public Content systemInstruction = null;
public List<Content> contents = new ArrayList<>();

@Override
public String toString() {
return "InstructionAndContent {\n" +
" systemInstruction = " + systemInstruction +
",\n contents = " + contents +
"\n}";
}
}

static InstructionAndContent splitInstructionAndContent(List<ChatMessage> messages) {
InstructionAndContent instructionAndContent = new InstructionAndContent();

List<Part> sysInstructionParts = new ArrayList<>();

for (ChatMessage message : messages) {
if (message instanceof SystemMessage) {
sysInstructionParts.addAll(PartsMapper.map(message));
List<ToolExecutionResultMessage> executionResultMessages = new ArrayList<>();

for (int msgIdx = 0; msgIdx < messages.size(); msgIdx++) {
ChatMessage message = messages.get(msgIdx);
boolean isLastMessage = msgIdx == messages.size() - 1;

if (message instanceof ToolExecutionResultMessage) {
ToolExecutionResultMessage toolResult = (ToolExecutionResultMessage) message;
if (isLastMessage) {
// if there's no accumulated tool results, add it right away to the list of messages
if (executionResultMessages.isEmpty()) {
instructionAndContent.contents.add(createContent(message));
} else { // otherwise add to the list, and create the new user message with all the tool results
executionResultMessages.add(toolResult);
instructionAndContent.contents.add(createToolExecutionResultContent(executionResultMessages));
}
} else { // not the last message, so just accumulate the new tool result
executionResultMessages.add(toolResult);
}
} else {
instructionAndContent.contents.add(Content.newBuilder()
.setRole(RoleMapper.map(message.type()))
.addAllParts(PartsMapper.map(message))
.build());
// if we're done with tool results and encounter a new user or AI message
// then bundle all the tool results into a new user message
if (!executionResultMessages.isEmpty()) {
instructionAndContent.contents.add(createToolExecutionResultContent(executionResultMessages));
executionResultMessages = new ArrayList<>();
}

// directly add user and AI messages to the list
if (message instanceof UserMessage || message instanceof AiMessage) {
instructionAndContent.contents.add(createContent(message));
} else if (message instanceof SystemMessage) { // save system messages separately
sysInstructionParts.addAll(PartsMapper.map(message));
}
}
}

instructionAndContent.systemInstruction = Content.newBuilder()
.setRole("system")
.addAllParts(sysInstructionParts)
.build();
// if there are system instructions, collect them together into one system instruction Content
if (!sysInstructionParts.isEmpty()) {
instructionAndContent.systemInstruction = Content.newBuilder()
.setRole("system")
.addAllParts(sysInstructionParts)
.build();
}

return instructionAndContent;
}

// transform a LangChain4j ChatMessage into a Gemini Content
private static Content createContent(ChatMessage message) {
return Content.newBuilder()
.setRole(RoleMapper.map(message.type()))
.addAllParts(PartsMapper.map(message))
.build();
}

// transform a list of LangChain4j tool execution results
// into a user message made of multiple Gemini Parts
private static Content createToolExecutionResultContent(List<ToolExecutionResultMessage> executionResultMessages) {
return Content.newBuilder()
.setRole(RoleMapper.map(ChatMessageType.TOOL_EXECUTION_RESULT))
.addAllParts(
executionResultMessages.stream()
.map(PartsMapper::map)
.flatMap(List::stream)
.collect(Collectors.toList()))
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package dev.langchain4j.model.vertexai;

/**
* Possible harm categories for the generation of responses that have been blocked by the model.
*/
public enum HarmCategory {
HARM_CATEGORY_UNSPECIFIED,
HARM_CATEGORY_HATE_SPEECH,
HARM_CATEGORY_DANGEROUS_CONTENT,
HARM_CATEGORY_HARASSMENT,
HARM_CATEGORY_SEXUALLY_EXPLICIT,
UNRECOGNIZED
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import com.google.cloud.vertexai.api.FunctionResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;

import java.net.URI;
import java.util.Base64;
Expand All @@ -19,6 +25,7 @@
import static com.google.cloud.vertexai.generativeai.PartMaker.fromMimeTypeAndData;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.quoted;
import static dev.langchain4j.internal.Utils.readBytes;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
Expand Down Expand Up @@ -86,7 +93,12 @@ static List<Part> map(ChatMessage message) {
try {
JsonFormat.parser().merge(functionResponseTextAsMap, structBuilder);
} catch (InvalidProtocolBufferException e2) {
throw new RuntimeException(e);
String functionResponseTextWithQuotesAsMap = "{\"result\":" + quoted(functionResponseText) + "}";
try {
JsonFormat.parser().merge(functionResponseTextWithQuotesAsMap, structBuilder);
} catch (InvalidProtocolBufferException e3) {
throw new RuntimeException(e3);
}
}
}
Struct responseStruct = structBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
import com.google.cloud.vertexai.api.Retrieval;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.VertexAISearch;

/**
* Ground Gemini responses with Google Search web results
* or with Vertex AI Search datastores
*/
class ResponseGrounding {

static Tool googleSearchTool() {
return Tool.newBuilder()
.setGoogleSearchRetrieval(
GoogleSearchRetrieval.newBuilder()
// .setDisableAttribution(false)
.build())
.build();
}

/**
* @param datastore fully qualified name of the Vertex Search datastore, with the following format
* "projects/PROJECT_ID/locations/global/collections/default_collection/dataStores/DATASTORE_NAME"
*/
static Tool vertexAiSearch(String datastore) {
return Tool.newBuilder()
.setRetrieval(
Retrieval.newBuilder()
.setVertexAiSearch(VertexAISearch.newBuilder().setDatastore(datastore))
.setDisableAttribution(false))
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.SafetySetting;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Maps between Vertex AI <code>SafetSetting</code> and LangChain4j
* <code>HarmCategoty</code> and <code>SafetyThreshold</code>
*/
class SafetySettingsMapper {
static List<SafetySetting> mapSafetySettings(Map<dev.langchain4j.model.vertexai.HarmCategory, SafetyThreshold> safetySettingsMap) {
return safetySettingsMap.entrySet().stream()
.map(entry -> {
SafetySetting.Builder safetySettingBuilder = SafetySetting.newBuilder();
safetySettingBuilder.setCategory(map(entry.getKey()));
safetySettingBuilder.setThreshold(map(entry.getValue()));
return safetySettingBuilder.build();
})
.collect(Collectors.toList());
}

private static HarmCategory map(dev.langchain4j.model.vertexai.HarmCategory harmCategory) {
return HarmCategory.valueOf(harmCategory.name());
}

private static HarmBlockThreshold map(SafetyThreshold safetyThreshold) {
return HarmBlockThreshold.valueOf(safetyThreshold.name());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package dev.langchain4j.model.vertexai;

/**
* Safety thresholds, for the harm categories for the generation of responses that have been blocked by the model.
*/
public enum SafetyThreshold {
HARM_BLOCK_THRESHOLD_UNSPECIFIED,
BLOCK_LOW_AND_ABOVE,
BLOCK_MEDIUM_AND_ABOVE,
BLOCK_ONLY_HIGH,
BLOCK_NONE,
UNRECOGNIZED
}
Loading

0 comments on commit c154203

Please sign in to comment.