Skip to content

Commit

Permalink
Sanitize messages before sending to Qwen models (langchain4j#1423)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1326 

## Change
Construct the message list according to the constraints of Qwen models.
(Otherwise, the model provider will return an exception.)
1. If there is a system message, it must be the first message.
Otherwise, a user message must be the first message.
2. User/Tool-execution-result messages and AI messages must alternate.
When consecutive messages are of the same kind, only the newest one is kept.
3. The last message in the message list should be a
user/tool-execution-result message. This serves as the model query/input
for the current round.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
  • Loading branch information
jiangsier-xyz authored Jul 10, 2024
1 parent 854f76a commit 2cdfb4a
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 14 deletions.
2 changes: 1 addition & 1 deletion langchain4j-dashscope/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.14.7</version>
<version>2.15.1</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.MultiModalMessage;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.tools.*;
import com.alibaba.dashscope.utils.JsonUtils;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
Expand All @@ -18,6 +19,7 @@
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import lombok.extern.slf4j.Slf4j;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -26,18 +28,22 @@
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

import static com.alibaba.dashscope.common.Role.*;
import static dev.langchain4j.data.message.ChatMessageType.*;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.stream.Collectors.toList;

@Slf4j
class QwenHelper {

static List<Message> toQwenMessages(List<ChatMessage> messages) {
return messages.stream()
return sanitizeMessages(messages)
.stream()
.map(QwenHelper::toQwenMessage)
.collect(toList());
}
Expand Down Expand Up @@ -97,7 +103,7 @@ static String toolCallIdFrom(ChatMessage message) {
}

static List<ToolCallBase> toolCallsFrom(ChatMessage message) {
if (message.type() == ChatMessageType.AI && ((AiMessage) message).hasToolExecutionRequests()) {
if (message.type() == AI && ((AiMessage) message).hasToolExecutionRequests()) {
return toToolCalls(((AiMessage) message).toolExecutionRequests());
}
return null;
Expand Down Expand Up @@ -189,14 +195,14 @@ private static String saveImageAsTemporaryFile(String base64Data, String mimeTyp
}

static String roleFrom(ChatMessage message) {
if (message.type() == ChatMessageType.AI) {
return ASSISTANT.getValue();
} else if (message.type() == ChatMessageType.SYSTEM) {
return SYSTEM.getValue();
if (message.type() == AI) {
return Role.ASSISTANT.getValue();
} else if (message.type() == SYSTEM) {
return Role.SYSTEM.getValue();
} else if (message.type() == ChatMessageType.TOOL_EXECUTION_RESULT) {
return TOOL.getValue();
return Role.TOOL.getValue();
} else {
return USER.getValue();
return Role.USER.getValue();
}
}

Expand Down Expand Up @@ -402,4 +408,67 @@ private static ToolCallBase toToolCall(ToolExecutionRequest toolExecutionRequest
toolCallFunction.setFunction(callFunction);
return toolCallFunction;
}

/**
 * Rebuilds the given conversation so that it satisfies the message-ordering rules applied
 * before sending messages to Qwen models.
 *
 * <p>Messages are folded one by one, in order, through {@link #messageAccumulator()}, which
 * decides for each message whether it is kept, dropped, or replaces previously kept messages.
 * Afterwards, trailing messages are removed until the last remaining message is a user or
 * tool-execution-result message (or the list is empty). Every dropped message is logged at
 * WARN level.
 *
 * @param messages the original conversation, in chronological order; must not be {@code null}
 * @return a new list holding the sanitized conversation; the input list is not modified
 */
static List<ChatMessage> sanitizeMessages(List<ChatMessage> messages) {
    // Fold with a plain sequential loop. The accumulator mutates the list it receives, which
    // Stream.reduce does not permit for its identity value and which only worked because the
    // stream happened to be sequential.
    BiFunction<LinkedList<ChatMessage>, ChatMessage, LinkedList<ChatMessage>> accumulator = messageAccumulator();
    LinkedList<ChatMessage> sanitizedMessages = new LinkedList<>();
    for (ChatMessage message : messages) {
        sanitizedMessages = accumulator.apply(sanitizedMessages, message);
    }

    // Ensure the last message is a user/tool_execution_result message
    while (!sanitizedMessages.isEmpty() && !isInputMessageType(sanitizedMessages.getLast().type())) {
        ChatMessage removedMessage = sanitizedMessages.removeLast();
        log.warn("The last message should be a user/tool_execution_result message, but found: {}", removedMessage);
    }

    return sanitizedMessages;
}

/**
 * Creates the folding step used to sanitize a conversation one message at a time.
 *
 * <p>The returned function mutates and returns the list it is given. For each incoming message:
 * <ul>
 *   <li>if nothing has been kept yet, only a system or user message is accepted;</li>
 *   <li>a later system message discards everything kept so far and becomes the first message;</li>
 *   <li>directly after a system message, only a user message is accepted;</li>
 *   <li>if the incoming message is on the same side (input vs. non-input, as decided by
 *       {@code isInputMessageType}) as the last kept message, it replaces that message;</li>
 *   <li>otherwise it is appended.</li>
 * </ul>
 * Every dropped message is logged at WARN level.
 */
private static BiFunction<LinkedList<ChatMessage>, ChatMessage, LinkedList<ChatMessage>> messageAccumulator() {
    return (kept, incoming) -> {
        ChatMessageType incomingType = incoming.type();

        // Nothing kept yet: the conversation may only open with a system or a user message.
        if (kept.isEmpty()) {
            if (incomingType != SYSTEM && incomingType != USER) {
                log.warn("The first message should be a system message or a user message, but found: {}", incoming);
            } else {
                kept.add(incoming);
            }
            return kept;
        }

        // A system message always restarts the conversation so that it ends up first.
        if (incomingType == SYSTEM) {
            log.warn("The system message should be the first message. Drop existed messages: {}", kept);
            kept.clear();
            kept.add(incoming);
            return kept;
        }

        ChatMessageType previousType = kept.getLast().type();

        // Right after the system message, anything other than a user message is skipped.
        if (incomingType != USER && previousType == SYSTEM) {
            log.warn("The first non-system message must be a user message, but found: {}", incoming);
            return kept;
        }

        // Input (user/tool result) and non-input messages must alternate;
        // on a repeat, the newer message replaces the older one.
        boolean sameSideAsPrevious = isInputMessageType(previousType) == isInputMessageType(incomingType);
        if (sameSideAsPrevious) {
            ChatMessage replaced = kept.removeLast();
            log.warn("User/Tool-execution-result messages and AI messages should alternate. Drop duplicated message: {}", replaced);
        }

        kept.add(incoming);
        return kept;
    };
}

/**
 * Creates the combiner for the message reduction. Merging two partial results is not
 * supported, so the returned operator always throws {@link UnsupportedOperationException}.
 */
private static BinaryOperator<LinkedList<ChatMessage>> messageCombiner() {
    return (left, right) -> {
        // Partial results are only produced by parallel streams, which are rejected here.
        throw new UnsupportedOperationException("Parallel stream not supported");
    };
}

/**
 * Tells whether a message type counts as model input.
 *
 * @param messageType the type to check
 * @return {@code true} for {@code USER} and {@code TOOL_EXECUTION_RESULT}, {@code false} otherwise
 */
private static boolean isInputMessageType(ChatMessageType messageType) {
    if (messageType == USER) {
        return true;
    }
    return messageType == TOOL_EXECUTION_RESULT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.LinkedList;
import java.util.List;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
Expand Down Expand Up @@ -236,4 +235,66 @@ public void should_send_multimodal_image_data_and_receive_response(String modelN

assertThat(response.content().text()).containsIgnoringCase("parrot");
}

@Test
public void should_sanitize_messages() {
    // Rules being verified:
    // 1. The system message should be the first message.
    // 2. User/Tool-execution-result messages and AI messages should alternate.
    // 3. The last message should be a user/tool-execution-result message,
    //    which serves as the model query/input for the current round.
    ToolExecutionRequest emptyRequest = ToolExecutionRequest.builder().build();
    List<ChatMessage> conversation = new LinkedList<>();

    // A later system message discards everything before it.
    conversation.add(SystemMessage.from("System message 1, which should be discarded"));
    conversation.add(UserMessage.from("User message 1, which should be discarded"));
    conversation.add(SystemMessage.from("System message 2"));

    // Only a user message may follow the system message; the newest of a run is kept.
    conversation.add(AiMessage.from("AI message 1, which should be discarded"));
    conversation.add(ToolExecutionResultMessage.from(emptyRequest,
            "Tool execution result 1, which should be discards"));
    conversation.add(UserMessage.from("User message 2, which should be discarded"));
    conversation.add(UserMessage.from("User message 3"));

    // Consecutive AI messages: the newest is kept.
    conversation.add(AiMessage.from("AI message 2, which should be discarded"));
    conversation.add(AiMessage.from("AI message 3"));

    // Consecutive tool execution results: the newest is kept.
    conversation.add(ToolExecutionResultMessage.from(emptyRequest,
            "Tool execution result 2, which should be discards"));
    conversation.add(ToolExecutionResultMessage.from(emptyRequest,
            "Tool execution result 3"));

    conversation.add(AiMessage.from("AI message 4"));

    // Consecutive user messages: the newest is kept.
    conversation.add(UserMessage.from("User message 4, which should be discards"));
    conversation.add(UserMessage.from("User message 5"));

    // A trailing AI message is removed.
    conversation.add(AiMessage.from("AI message 5, which should be discards"));

    // Expected order: system, user, AI, tool execution result, AI, user.
    List<ChatMessage> result = QwenHelper.sanitizeMessages(conversation);
    assertThat(result).hasSize(6);

    ChatMessage first = result.get(0);
    assertThat(first).isInstanceOf(SystemMessage.class);
    assertThat(((SystemMessage) first).text()).isEqualTo("System message 2");

    ChatMessage second = result.get(1);
    assertThat(second).isInstanceOf(UserMessage.class);
    assertThat(((UserMessage) second).singleText()).isEqualTo("User message 3");

    ChatMessage third = result.get(2);
    assertThat(third).isInstanceOf(AiMessage.class);
    assertThat(((AiMessage) third).text()).isEqualTo("AI message 3");

    ChatMessage fourth = result.get(3);
    assertThat(fourth).isInstanceOf(ToolExecutionResultMessage.class);
    assertThat(((ToolExecutionResultMessage) fourth).text()).isEqualTo("Tool execution result 3");

    ChatMessage fifth = result.get(4);
    assertThat(fifth).isInstanceOf(AiMessage.class);
    assertThat(((AiMessage) fifth).text()).isEqualTo("AI message 4");

    ChatMessage sixth = result.get(5);
    assertThat(sixth).isInstanceOf(UserMessage.class);
    assertThat(((UserMessage) sixth).singleText()).isEqualTo("User message 5");
}
}

0 comments on commit 2cdfb4a

Please sign in to comment.