Skip to content

Commit

Permalink
Gemini function calling support (langchain4j#692)
Browse files Browse the repository at this point in the history
- Add Gemini function calling support, in both streaming and
non-streaming models
- Upgrade the Gemini Java SDK

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced functionality for converting tool execution requests into
function calls and vice versa, enhancing interaction capabilities.
- Expanded message handling in chat models to support different message
types and tool execution results.
- Added support for generating chat messages based on tool
specifications, allowing for more dynamic conversations.

- **Enhancements**
- Improved type conversion and message construction processes for better
efficiency and reliability.

- **Tests**
- Added comprehensive test coverage for new functionalities, ensuring
robustness and reliability of conversions and message handling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
glaforge authored Mar 11, 2024
1 parent f565efc commit 7fa0adc
Show file tree
Hide file tree
Showing 10 changed files with 607 additions and 101 deletions.
2 changes: 1 addition & 1 deletion langchain4j-vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
<artifactId>libraries-bom</artifactId>
<scope>import</scope>
<type>pom</type>
<version>26.29.0</version>
<version>26.34.0</version>
</dependency>
</dependencies>
</dependencyManagement>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.*;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutionRequestUtil;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Conversion helpers between LangChain4j tool abstractions
 * ({@link ToolSpecification}, {@link ToolExecutionRequest}) and the Vertex AI Gemini
 * function-calling protobuf messages ({@link Tool}, {@link FunctionCall}, {@link Schema}).
 */
class FunctionCallHelper {

    /**
     * Maps a JSON-schema type name to the corresponding Gemini schema {@link Type}.
     *
     * @param type the JSON-schema type name (e.g. {@code "string"}); may be {@code null}
     * @return the matching {@link Type}, or {@link Type#TYPE_UNSPECIFIED} when the name
     *         is {@code null} or not one of the handled names
     */
    static Type fromType(String type) {
        // A switch on a null String throws NullPointerException, so handle it up front.
        if (type == null) {
            return Type.TYPE_UNSPECIFIED;
        }
        //TODO: is it covering all the types correctly?
        switch (type) {
            case "string":
                return Type.STRING;
            case "integer":
                return Type.INTEGER;
            case "boolean":
                return Type.BOOLEAN;
            case "number":
                return Type.NUMBER;
            case "array":
                return Type.ARRAY;
            case "object":
                return Type.OBJECT;
            default:
                return Type.TYPE_UNSPECIFIED;
        }
    }

    /**
     * Converts a tool execution request into a Gemini {@link FunctionCall}.
     * The request's JSON arguments are parsed into the call's {@link Struct} args;
     * a {@code null} or blank arguments string produces a call with empty args.
     *
     * @param toolExecutionRequest the request to convert
     * @return the function call carrying the request's name and parsed arguments
     * @throws RuntimeException if the arguments cannot be parsed as a JSON object
     *                          (the parser's exception is kept as the cause)
     */
    static FunctionCall fromToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
        FunctionCall.Builder fnCallBuilder = FunctionCall.newBuilder()
                .setName(toolExecutionRequest.name());

        Struct.Builder structBuilder = Struct.newBuilder();
        String arguments = toolExecutionRequest.arguments();
        // Nothing to merge when the request carries no arguments at all.
        if (arguments != null && !arguments.trim().isEmpty()) {
            try {
                JsonFormat.parser().merge(arguments, structBuilder);
            } catch (InvalidProtocolBufferException e) {
                throw new RuntimeException(
                        "Cannot parse the arguments of tool '" + toolExecutionRequest.name() + "' as a JSON object", e);
            }
        }
        fnCallBuilder.setArgs(structBuilder.build());

        return fnCallBuilder.build();
    }

    /**
     * Converts Gemini function calls into tool execution requests.
     * Each call's args are unwrapped into plain Java values and serialized to JSON
     * with {@code ToolExecutionRequestUtil.GSON}.
     *
     * @param functionCalls the function calls to convert
     * @return one request per function call, in the same order
     */
    static List<ToolExecutionRequest> fromFunctionCalls(List<FunctionCall> functionCalls) {
        List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();

        for (FunctionCall functionCall : functionCalls) {
            Map<String, Object> callArgsMap = new HashMap<>();
            functionCall.getArgs().getFieldsMap()
                    .forEach((key, value) -> callArgsMap.put(key, unwrapProtoValue(value)));

            toolExecutionRequests.add(ToolExecutionRequest.builder()
                    .name(functionCall.getName())
                    .arguments(ToolExecutionRequestUtil.GSON.toJson(callArgsMap))
                    .build());
        }

        return toolExecutionRequests;
    }

    /**
     * Recursively unwraps a protobuf {@link Value} into a plain Java object:
     * a {@code Double}, {@code String}, {@code Boolean}, {@code Map<String, Object>}
     * (for structs) or {@code List<Object>} (for lists).
     *
     * @param value the protobuf value to unwrap
     * @return the unwrapped object, or {@code null} for null or unset values
     */
    static Object unwrapProtoValue(Value value) {
        switch (value.getKindCase()) {
            case NUMBER_VALUE:
                return value.getNumberValue();
            case STRING_VALUE:
                return value.getStringValue();
            case BOOL_VALUE:
                return value.getBoolValue();
            case STRUCT_VALUE:
                Map<String, Object> mapForStruct = new HashMap<>();
                value.getStructValue().getFieldsMap()
                        .forEach((key, val) -> mapForStruct.put(key, unwrapProtoValue(val)));
                return mapForStruct;
            case LIST_VALUE:
                return value.getListValue().getValuesList().stream()
                        .map(FunctionCallHelper::unwrapProtoValue)
                        .collect(Collectors.toList());
            default: // NULL_VALUE, KIND_NOT_SET, and default
                return null;
        }
    }

    /**
     * Builds a single Gemini {@link Tool} holding one {@link FunctionDeclaration} per
     * tool specification. Each declaration gets an {@code OBJECT} schema whose properties
     * carry the type and description of the tool's parameters.
     * <p>
     * A missing tool description, missing parameters, or a parameter without a
     * {@code "type"} / {@code "description"} entry is tolerated: the corresponding field
     * is left unset, empty, or {@link Type#TYPE_UNSPECIFIED}.
     *
     * @param toolSpecifications the tool specifications to convert
     * @return the tool containing all function declarations
     */
    static Tool convertToolSpecifications(List<ToolSpecification> toolSpecifications) {
        Tool.Builder tool = Tool.newBuilder();

        for (ToolSpecification toolSpecification : toolSpecifications) {
            FunctionDeclaration.Builder fnBuilder = FunctionDeclaration.newBuilder()
                    .setName(toolSpecification.name());

            // Protobuf setters reject null, and a tool is not required to have a description.
            String toolDescription = toolSpecification.description();
            if (toolDescription != null) {
                fnBuilder.setDescription(toolDescription);
            }

            Schema.Builder schema = Schema.newBuilder().setType(Type.OBJECT);

            ToolParameters parameters = toolSpecification.parameters();
            if (parameters != null) {
                if (parameters.required() != null) {
                    for (String paramName : parameters.required()) {
                        schema.addRequired(paramName);
                    }
                }
                if (parameters.properties() != null) {
                    parameters.properties().forEach((paramName, paramProps) -> {
                        //TODO: is it covering all types & cases of tool parameters? (array & object in particular)
                        // get() instead of getOrDefault(): the previous default was a Type enum
                        // that was then cast to String, failing for every parameter without a type.
                        Type type = fromType((String) paramProps.get("type"));

                        String description = (String) paramProps.get("description");

                        schema.putProperties(paramName, Schema.newBuilder()
                                .setDescription(description == null ? "" : description)
                                .setType(type)
                                .build());
                    });
                }
            }
            fnBuilder.setParameters(schema.build());
            tool.addFunctionDeclarations(fnBuilder.build());
        }

        return tool.build();
    }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.FunctionResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;

Expand All @@ -10,7 +14,7 @@
import java.util.List;
import java.util.Map;

import static com.google.cloud.vertexai.generativeai.preview.PartMaker.fromMimeTypeAndData;
import static com.google.cloud.vertexai.generativeai.PartMaker.fromMimeTypeAndData;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.readBytes;
Expand All @@ -36,14 +40,45 @@ class PartsMapper {
}

static List<Part> map(ChatMessage message) {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;

if (aiMessage.hasToolExecutionRequests()) {
return singletonList(Part.newBuilder()
.setFunctionCall(
//TODO: handling one function call, but can there be several?

FunctionCallHelper.fromToolExecutionRequest(aiMessage.toolExecutionRequests().get(0))
)
.build());
} else {
return singletonList(Part.newBuilder()
.setText(aiMessage.text())
.build());
}
} else
if (message instanceof UserMessage) {
return ((UserMessage) message).contents().stream()
.map(PartsMapper::map)
.collect(toList());
} else if (message instanceof AiMessage) {
.map(PartsMapper::map)
.collect(toList());
} else if (message instanceof ToolExecutionResultMessage) {
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
String functionResponseText = toolExecutionResultMessage.text();

Struct.Builder structBuilder = Struct.newBuilder();
try {
JsonFormat.parser().merge(functionResponseText, structBuilder);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
Struct responseStruct = structBuilder.build();

return singletonList(Part.newBuilder()
.setText(((AiMessage) message).text())
.build());
.setFunctionResponse(FunctionResponse.newBuilder()
.setName(toolExecutionResultMessage.toolName())
.setResponse(responseStruct)
.build())
.build());
} else {
throw illegalArgument(message.type() + " message is not supported by Gemini");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class RoleMapper {

static String map(ChatMessageType type) {
switch (type) {
case TOOL_EXECUTION_RESULT:
case USER:
return "user";
case AI:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.api.Candidate;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.generativeai.preview.ResponseHandler;
import com.google.cloud.vertexai.api.*;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

class StreamingChatResponseBuilder {

private final StringBuffer contentBuilder = new StringBuffer();

private final List<FunctionCall> functionCalls = new ArrayList<>();

private volatile TokenUsage tokenUsage;
private volatile FinishReason finishReason;

Expand All @@ -26,7 +30,19 @@ void append(GenerateContentResponse partialResponse) {
return;
}

contentBuilder.append(ResponseHandler.getText(partialResponse));
List<FunctionCall> functionCalls = candidates.stream()
.map(Candidate::getContent)
.map(Content::getPartsList)
.flatMap(List::stream)
.filter(Part::hasFunctionCall)
.map(Part::getFunctionCall)
.collect(Collectors.toList());

if (!functionCalls.isEmpty()) {
this.functionCalls.addAll(functionCalls);
} else {
contentBuilder.append(ResponseHandler.getText(partialResponse));
}

if (partialResponse.hasUsageMetadata()) {
tokenUsage = TokenUsageMapper.map(partialResponse.getUsageMetadata());
Expand All @@ -39,10 +55,18 @@ void append(GenerateContentResponse partialResponse) {
}

Response<AiMessage> build() {
return Response.from(
if (!functionCalls.isEmpty()) {
return Response.from(
AiMessage.from(FunctionCallHelper.fromFunctionCalls(functionCalls)),
tokenUsage,
finishReason
);
} else {
return Response.from(
AiMessage.from(contentBuilder.toString()),
tokenUsage,
finishReason
);
);
}
}
}
Loading

0 comments on commit 7fa0adc

Please sign in to comment.