Skip to content

Commit

Permalink
Support ToolExecutionResultMessage on Qwen Models (langchain4j#1260)
Browse files Browse the repository at this point in the history
The current implementation supports generating tool calls, but does not
support ToolExecutionResultMessage. This makes it impossible to feed
the results of tool execution back to the model. This patch adds
support for ToolExecutionResultMessage.

<!-- Thank you so much for your contribution! -->
<!-- Please fill in all the sections below. -->

<!-- Please open the PR as a draft initially. Once it is reviewed and
approved, we will ask you to add documentation and examples. -->
<!-- Please note that PRs with breaking changes will be rejected. -->
<!-- Please note that PRs without tests will be rejected. -->

<!-- Please note that PRs will be reviewed based on the priority of the
issues they address. -->
<!-- We ask for your patience. We are doing our best to review your PR
as quickly as possible. -->
<!-- Please refrain from pinging and asking when it will be reviewed.
Thank you for understanding! -->


## Issue
<!-- Please paste the link to the issue this PR is addressing. For
example: langchain4j#1012 -->


## Change
<!-- Please describe the changes you made. -->


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
  • Loading branch information
jiangsier-xyz authored Jun 11, 2024
1 parent eba2b1c commit b74270b
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 121 deletions.
2 changes: 1 addition & 1 deletion docs/docs/integrations/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ of course some LLM providers offer large multimodal model (accepting text or ima
| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | ||| || ||
| [Google Vertex AI](/integrations/language-models/google-palm) ||| ||| | |
| [Mistral AI](/integrations/language-models/mistral-ai) | |||| | ||
| [DashScope](/integrations/language-models/dashscope) | |||| | | |
| [DashScope](/integrations/language-models/dashscope) | ||| | | | |
| [LocalAI](/integrations/language-models/local-ai) | |||| | ||
| [Ollama](/integrations/language-models/ollama) | |||| | | |
| Cohere | | | | | || |
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/integrations/language-models/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sidebar_position: 0
| [Anthropic](/integrations/language-models/anthropic) |||| ||
| [Azure OpenAI](/integrations/language-models/azure-open-ai) |||| | |
| [ChatGLM](/integrations/language-models/chatglm) | | | | | |
| [DashScope](/integrations/language-models/dashscope) || | | | |
| [DashScope](/integrations/language-models/dashscope) || | | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) |||| | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | ||
| [Hugging Face](/integrations/language-models/hugging-face) | | | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,28 @@ protected QwenChatModel(String baseUrl,

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
    // No tools requested: route to the multimodal (qwen-vl*) path or the
    // plain-text path depending on the configured model.
    return isMultimodalModel ?
            generateByMultimodalModel(messages, null, null) :
            generateByNonMultimodalModel(messages, null, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
    // Tools offered to the model (model may choose whether to call them).
    // The multimodal path rejects tools with IllegalArgumentException.
    return isMultimodalModel ?
            generateByMultimodalModel(messages, toolSpecifications, null) :
            generateByNonMultimodalModel(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
    // A single tool that MUST be executed by the model.
    // The multimodal path rejects tools with IllegalArgumentException.
    return isMultimodalModel ?
            generateByMultimodalModel(messages, null, toolSpecification) :
            generateByNonMultimodalModel(messages, null, toolSpecification);
}

private Response<AiMessage> generateByNonMultimodalModel(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
private Response<AiMessage> generateByNonMultimodalModel(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
try {
GenerationParam.GenerationParamBuilder<?, ?> builder = GenerationParam.builder()
.apiKey(apiKey)
Expand All @@ -127,7 +131,7 @@ private Response<AiMessage> generateByNonMultimodalModel(
builder.tools(toToolFunctions(toolSpecifications));
} else if (toolThatMustBeExecuted != null) {
builder.tools(toToolFunctions(Collections.singleton(toolThatMustBeExecuted)));
builder.toolChoice(buildToolChoiceStrategy(toolThatMustBeExecuted));
builder.toolChoice(toToolFunction(toolThatMustBeExecuted));
}

GenerationResult generationResult = generation.call(builder.build());
Expand All @@ -142,7 +146,13 @@ private Response<AiMessage> generateByNonMultimodalModel(
}
}

private Response<AiMessage> generateByMultimodalModel(List<ChatMessage> messages) {
private Response<AiMessage> generateByMultimodalModel(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
if (toolThatMustBeExecuted != null || !isNullOrEmpty(toolSpecifications)) {
throw new IllegalArgumentException("Tools are currently not supported by this model");
}

try {
MultiModalConversationParam param = MultiModalConversationParam.builder()
.apiKey(apiKey)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

Expand All @@ -31,8 +30,7 @@

import static com.alibaba.dashscope.common.Role.*;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.stream.Collectors.toList;

class QwenHelper {
Expand All @@ -53,6 +51,9 @@ static Message toQwenMessage(ChatMessage message) {
return Message.builder()
.role(roleFrom(message))
.content(toSingleText(message))
.name(nameFrom(message))
.toolCallId(toolCallIdFrom(message))
.toolCalls(toolCallsFrom(message))
.build();
}

Expand All @@ -66,7 +67,7 @@ static String toSingleText(ChatMessage message) {
.map(TextContent::text)
.collect(Collectors.joining("\n"));
case AI:
return ((AiMessage) message).text();
return ((AiMessage) message).hasToolExecutionRequests() ? "" : ((AiMessage) message).text();
case SYSTEM:
return ((SystemMessage) message).text();
case TOOL_EXECUTION_RESULT:
Expand All @@ -76,6 +77,31 @@ static String toSingleText(ChatMessage message) {
}
}

/**
 * Extracts the name to attach to a Qwen message, when the message type carries one.
 * User messages may carry the user's name; tool results carry the invoked tool's name.
 *
 * @param message the chat message being converted
 * @return the associated name, or null for message types that have none
 */
static String nameFrom(ChatMessage message) {
    if (message.type() == ChatMessageType.USER) {
        return ((UserMessage) message).name();
    }
    if (message.type() == ChatMessageType.TOOL_EXECUTION_RESULT) {
        return ((ToolExecutionResultMessage) message).toolName();
    }
    return null;
}

/**
 * Extracts the tool-call id that a tool-execution-result message answers.
 *
 * @param message the chat message being converted
 * @return the originating tool-call id, or null for any other message type
 */
static String toolCallIdFrom(ChatMessage message) {
    return message.type() == ChatMessageType.TOOL_EXECUTION_RESULT
            ? ((ToolExecutionResultMessage) message).id()
            : null;
}

/**
 * Converts the tool-execution requests carried by an AI message into
 * DashScope tool calls.
 *
 * @param message the chat message being converted
 * @return the converted tool calls, or null when the message is not an
 *         AI message or carries no tool-execution requests
 */
static List<ToolCallBase> toolCallsFrom(ChatMessage message) {
    if (message.type() != ChatMessageType.AI) {
        return null;
    }
    AiMessage aiMessage = (AiMessage) message;
    return aiMessage.hasToolExecutionRequests()
            ? toToolCalls(aiMessage.toolExecutionRequests())
            : null;
}

static List<MultiModalMessage> toQwenMultiModalMessages(List<ChatMessage> messages) {
return messages.stream()
.map(QwenHelper::toQwenMultiModalMessage)
Expand Down Expand Up @@ -162,10 +188,12 @@ private static String saveImageAsTemporaryFile(String base64Data, String mimeTyp
}

static String roleFrom(ChatMessage message) {
if (message instanceof AiMessage) {
if (message.type() == ChatMessageType.AI) {
return ASSISTANT.getValue();
} else if (message instanceof SystemMessage) {
} else if (message.type() == ChatMessageType.SYSTEM) {
return SYSTEM.getValue();
} else if (message.type() == ChatMessageType.TOOL_EXECUTION_RESULT) {
return TOOL.getValue();
} else {
return USER.getValue();
}
Expand Down Expand Up @@ -236,19 +264,23 @@ static TokenUsage tokenUsageFrom(MultiModalConversationResult result) {
}

static FinishReason finishReasonFrom(GenerationResult result) {
String finishReason = Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getChoices)
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(Choice::getFinishReason)
.orElse("");
Choice choice = result.getOutput().getChoices().get(0);
String finishReason = choice.getFinishReason();
if (finishReason == null) {
if (isNullOrEmpty(choice.getMessage().getToolCalls())) {
return null;
}
// Upon observation, when tool_calls occur, the returned finish_reason may be null, not "tool_calls".
finishReason = "tool_calls";
}

switch (finishReason) {
case "stop":
return STOP;
case "length":
return LENGTH;
case "tool_calls":
return TOOL_EXECUTION;
default:
return null;
}
Expand Down Expand Up @@ -278,65 +310,29 @@ public static boolean isMultimodalModel(String modelName) {
return modelName.startsWith("qwen-vl");
}

/**
 * Converts langchain4j tool specifications into DashScope tool functions.
 *
 * @param toolSpecifications the tools to expose to the model; may be null or empty
 * @return the DashScope tool representations, or an empty list when none are given
 */
static List<ToolBase> toToolFunctions(Collection<ToolSpecification> toolSpecifications) {
    if (isNullOrEmpty(toolSpecifications)) {
        return Collections.emptyList();
    }

    // Delegate per-tool conversion so the single-tool path (toolChoice) and
    // the multi-tool path share one implementation.
    return toolSpecifications.stream()
            .map(QwenHelper::toToolFunction)
            .collect(Collectors.toList());
}

private static JsonObject toParameters(ToolParameters toolParameters) {
QwenParameters qwenParameters = QwenParameters.from(toolParameters);
return JsonUtils.parseString(JsonUtils.toJson(qwenParameters)).getAsJsonObject();
}

/**
* Because of the interface definition, only implement the strategy of "must be called" here.{@link ChatLanguageModel}
*
* @param toolThatMustBeExecuted {@link ToolSpecification}
* @return tool choice strategy
* More details are available <a href="https://help.aliyun.com/zh/dashscope/developer-reference/api-details">here</a>.
*/
static ToolChoiceStrategy buildToolChoiceStrategy(ToolSpecification toolThatMustBeExecuted) {
return new ToolChoiceStrategy(new ToolChoiceFunction(toolThatMustBeExecuted.name()));
}

private static class ToolChoiceStrategy {
private final String type = "function";

private final ToolChoiceFunction function;

public ToolChoiceStrategy(ToolChoiceFunction function) {
this.function = function;
}
/**
 * Adapts a single langchain4j tool specification to DashScope's
 * function-tool wrapper.
 *
 * @param toolSpecification the tool to convert
 * @return the DashScope tool-function wrapper for the given specification
 */
static ToolBase toToolFunction(ToolSpecification toolSpecification) {
    return ToolFunction.builder()
            .function(FunctionDefinition.builder()
                    .name(toolSpecification.name())
                    .description(toolSpecification.description())
                    .parameters(toParameters(toolSpecification.parameters()))
                    .build())
            .build();
}

private static class ToolChoiceFunction {
private final String name;

public ToolChoiceFunction(String name) {
this.name = name;
}
/**
 * Serializes a tool's parameter schema to the JSON object DashScope expects.
 * DashScope requires a JSON object even for tools that take no parameters,
 * so a null schema is rendered as an empty object.
 *
 * @param toolParameters the parameter schema; may be null
 * @return the JSON-object form of the schema (empty object when null)
 */
private static JsonObject toParameters(ToolParameters toolParameters) {
    if (toolParameters == null) {
        return JsonUtils.toJsonObject(Collections.emptyMap());
    }
    return JsonUtils.toJsonObject(toolParameters);
}

static AiMessage aiMessageFrom(GenerationResult result) {
Expand Down Expand Up @@ -375,4 +371,20 @@ static boolean isFunctionToolCalls(GenerationResult result) {
.map(Message::getToolCalls);
return toolCallBases.isPresent() && !isNullOrEmpty(toolCallBases.get());
}

/**
 * Converts a batch of langchain4j tool-execution requests into
 * DashScope tool calls.
 *
 * @param toolExecutionRequests the requests issued by the model
 * @return the DashScope tool-call representations, in iteration order
 */
private static List<ToolCallBase> toToolCalls(Collection<ToolExecutionRequest> toolExecutionRequests) {
    return toolExecutionRequests.stream()
            .map(request -> toToolCall(request))
            .collect(Collectors.toList());
}

/**
 * Converts one langchain4j tool-execution request into a DashScope
 * function tool call, copying the id, tool name and raw argument string.
 *
 * @param toolExecutionRequest the request issued by the model
 * @return the equivalent DashScope tool call
 */
private static ToolCallBase toToolCall(ToolExecutionRequest toolExecutionRequest) {
    ToolCallFunction toolCall = new ToolCallFunction();
    // CallFunction is a non-static inner class, so it must be created
    // through the enclosing ToolCallFunction instance.
    ToolCallFunction.CallFunction callFunction = toolCall.new CallFunction();
    callFunction.setName(toolExecutionRequest.name());
    callFunction.setArguments(toolExecutionRequest.arguments());
    toolCall.setId(toolExecutionRequest.id());
    toolCall.setFunction(callFunction);
    return toolCall;
}
}

This file was deleted.

Loading

0 comments on commit b74270b

Please sign in to comment.