Skip to content

Commit

Permalink
Feature langchain4j#1210 Support Tongyi Qianwen(QwenChatModel) functi…
Browse files Browse the repository at this point in the history
…on call (langchain4j#1254)

<!-- Thank you so much for your contribution! -->
<!-- Please fill in all the sections below. -->

<!-- Please open the PR as a draft initially. Once it is reviewed and
approved, we will ask you to add documentation and examples. -->
<!-- Please note that PRs with breaking changes will be rejected. -->
<!-- Please note that PRs without tests will be rejected. -->

<!-- Please note that PRs will be reviewed based on the priority of the
issues they address. -->
<!-- We ask for your patience. We are doing our best to review your PR
as quickly as possible. -->
<!-- Please refrain from pinging and asking when it will be reviewed.
Thank you for understanding! -->


## Issue
<!-- Please paste the link to the issue this PR is addressing. For
example: langchain4j#1012 -->


## Change
<!-- Please describe the changes you made. -->


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)


## Checklist for adding new model integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for adding new embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends
from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT`
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for changing existing embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have manually verified that the
`{NameOfIntegration}EmbeddingStore` works correctly with the data
persisted using the latest released version of LangChain4j
  • Loading branch information
Kugaaa authored Jun 10, 2024
1 parent 6b78a77 commit d580506
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 10 deletions.
2 changes: 1 addition & 1 deletion langchain4j-dashscope/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.14.4</version>
<version>2.14.7</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.exception.UploadFileException;
import com.alibaba.dashscope.protocol.Protocol;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
Expand All @@ -18,9 +19,11 @@
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.util.Collections;
import java.util.List;

import static com.alibaba.dashscope.aigc.conversation.ConversationParam.ResultFormat.MESSAGE;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.dashscope.QwenHelper.*;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;

Expand Down Expand Up @@ -84,10 +87,24 @@ protected QwenChatModel(String baseUrl,

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return isMultimodalModel ? generateByMultimodalModel(messages) : generateByNonMultimodalModel(messages);
return isMultimodalModel ? generateByMultimodalModel(messages) : generateByNonMultimodalModel(messages, null, null);
}

private Response<AiMessage> generateByNonMultimodalModel(List<ChatMessage> messages) {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generateByNonMultimodalModel(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generateByNonMultimodalModel(messages, null, toolSpecification);
}

private Response<AiMessage> generateByNonMultimodalModel(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
try {
GenerationParam.GenerationParamBuilder<?, ?> builder = GenerationParam.builder()
.apiKey(apiKey)
Expand All @@ -106,11 +123,20 @@ private Response<AiMessage> generateByNonMultimodalModel(List<ChatMessage> messa
builder.stopStrings(stops);
}

if (!isNullOrEmpty(toolSpecifications)) {
builder.tools(toToolFunctions(toolSpecifications));
} else if (toolThatMustBeExecuted != null) {
builder.tools(toToolFunctions(Collections.singleton(toolThatMustBeExecuted)));
builder.toolChoice(buildToolChoiceStrategy(toolThatMustBeExecuted));
}

GenerationResult generationResult = generation.call(builder.build());
String answer = answerFrom(generationResult);

return Response.from(AiMessage.from(answer),
tokenUsageFrom(generationResult), finishReasonFrom(generationResult));
return Response.from(
aiMessageFrom(generationResult),
tokenUsageFrom(generationResult),
finishReasonFrom(generationResult)
);
} catch (NoApiKeyException | InputRequiredException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.MultiModalMessage;
import com.alibaba.dashscope.tools.*;
import com.alibaba.dashscope.utils.JsonUtils;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import com.google.gson.JsonObject;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

Expand All @@ -23,6 +30,7 @@
import java.util.stream.Collectors;

import static com.alibaba.dashscope.common.Role.*;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.stream.Collectors.toList;
Expand Down Expand Up @@ -84,7 +92,7 @@ static MultiModalMessage toQwenMultiModalMessage(ChatMessage message) {
static List<Map<String, Object>> toMultiModalContents(ChatMessage message) {
switch (message.type()) {
case USER:
return((UserMessage) message).contents()
return ((UserMessage) message).contents()
.stream()
.map(QwenHelper::toMultiModalContent)
.collect(Collectors.toList());
Expand Down Expand Up @@ -269,4 +277,102 @@ public static boolean isMultimodalModel(String modelName) {
// for now, multimodal models start with "qwen-vl"
return modelName.startsWith("qwen-vl");
}
}

/**
* build ToolFunction(ToolBase) coll from ToolSpecification coll
*
* @param toolSpecifications {@link ToolSpecification}
* @return {@link ToolFunction}
*/
static List<ToolBase> toToolFunctions(Collection<ToolSpecification> toolSpecifications) {
if (isNullOrEmpty(toolSpecifications)) {
return Collections.emptyList();
}

return toolSpecifications.stream()
.map(tool -> FunctionDefinition
.builder()
.name(tool.name())
.description(tool.description())
.parameters(toParameters(tool.parameters()))
.build()
)
.map(definition -> (ToolBase) ToolFunction
.builder()
.function(definition)
.build()
)
.collect(Collectors.toList());
}

private static JsonObject toParameters(ToolParameters toolParameters) {
QwenParameters qwenParameters = QwenParameters.from(toolParameters);
return JsonUtils.parseString(JsonUtils.toJson(qwenParameters)).getAsJsonObject();
}

/**
* Because of the interface definition, only implement the strategy of "must be called" here.{@link ChatLanguageModel}
*
* @param toolThatMustBeExecuted {@link ToolSpecification}
* @return tool choice strategy
* More details are available <a href="https://help.aliyun.com/zh/dashscope/developer-reference/api-details">here</a>.
*/
static ToolChoiceStrategy buildToolChoiceStrategy(ToolSpecification toolThatMustBeExecuted) {
return new ToolChoiceStrategy(new ToolChoiceFunction(toolThatMustBeExecuted.name()));
}

private static class ToolChoiceStrategy {
private final String type = "function";

private final ToolChoiceFunction function;

public ToolChoiceStrategy(ToolChoiceFunction function) {
this.function = function;
}
}

private static class ToolChoiceFunction {
private final String name;

public ToolChoiceFunction(String name) {
this.name = name;
}
}

static AiMessage aiMessageFrom(GenerationResult result) {
return isFunctionToolCalls(result) ?
new AiMessage(functionToolCallsFrom(result)) : new AiMessage(answerFrom(result));
}

private static List<ToolExecutionRequest> functionToolCallsFrom(GenerationResult result) {
List<ToolCallBase> toolCalls = Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getChoices)
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(Choice::getMessage)
.map(Message::getToolCalls)
.orElseThrow(IllegalStateException::new);

return toolCalls.stream()
.filter(toolCall -> toolCall instanceof ToolCallFunction)
.map(toolCall -> (ToolCallFunction) toolCall)
.map(toolCall -> ToolExecutionRequest.builder()
.id(toolCall.getId())
.name(toolCall.getFunction().getName())
.arguments(toolCall.getFunction().getArguments())
.build())
.collect(Collectors.toList());
}

static boolean isFunctionToolCalls(GenerationResult result) {
Optional<List<ToolCallBase>> toolCallBases = Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getChoices)
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(Choice::getMessage)
.map(Message::getToolCalls);
return toolCallBases.isPresent() && !isNullOrEmpty(toolCallBases.get());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package dev.langchain4j.model.dashscope;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import dev.langchain4j.agent.tool.ToolParameters;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonInclude(NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(SnakeCaseStrategy.class)
public class QwenParameters {
private String type = "object";
private Map<String, Map<String, Object>> properties;
private List<String> required;

private static final QwenParameters EMPTY_PARAMETERS_INSTANT = QwenParameters.builder().build();

public static QwenParameters from(ToolParameters toolParameters) {
if (toolParameters == null) {
return EMPTY_PARAMETERS_INSTANT;
}

return QwenParameters.builder()
.properties(toolParameters.properties())
.required(toolParameters.required())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package dev.langchain4j.model.dashscope;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.Collections;

import static dev.langchain4j.model.dashscope.QwenTestHelper.*;
import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -27,13 +33,91 @@ public void should_send_non_multimodal_messages_and_receive_response(String mode
assertThat(response.content().text()).containsIgnoringCase("rain");
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider")
public void should_call_function_with_no_argument(String modelName) {
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();

String toolName = "getCurrentDateAndTime";
ToolSpecification noArgToolSpec = ToolSpecification.builder()
.name(toolName)
.description("Get the current date and time")
.build();

UserMessage userMessage = UserMessage.from("What time is it?");

Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), Collections.singletonList(noArgToolSpec));

assertThat(response.content().text()).isNull();
assertThat(response.content().toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
assertThat(toolExecutionRequest.arguments()).isEqualTo("{}");
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider")
public void should_call_function_with_argument(String modelName) {
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();

String toolName = "getCurrentWeather";
ToolSpecification hasArgToolSpec = ToolSpecification.builder()
.name(toolName)
.description("Query the weather of a specified city")
.addParameter("cityName", JsonSchemaProperty.STRING)
.build();

UserMessage userMessage = UserMessage.from("Weather in Beijing?");

Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), Collections.singletonList(hasArgToolSpec));

assertThat(response.content().text()).isNull();
assertThat(response.content().toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
assertThat(toolExecutionRequest.arguments()).contains("Beijing");
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider")
public void should_call_must_be_executed_call_function(String modelName) {
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();

String toolName = "getCurrentWeather";
ToolSpecification mustBeExecutedTool = ToolSpecification.builder()
.name(toolName)
.description("Query the weather of a specified city")
.addParameter("cityName", JsonSchemaProperty.STRING)
.build();

// not related to tools
UserMessage userMessage = UserMessage.from("How many students in the classroom?");

Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), mustBeExecutedTool);

assertThat(response.content().text()).isNull();
assertThat(response.content().toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);
assertThat(toolExecutionRequest.arguments()).hasSizeGreaterThan(0);
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#multimodalChatModelNameProvider")
public void should_send_multimodal_image_url_and_receive_response(String modelName) {
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();;
.build();

Response<AiMessage> response = model.generate(multimodalChatMessagesWithImageUrl());
System.out.println(response);
Expand All @@ -47,7 +131,7 @@ public void should_send_multimodal_image_data_and_receive_response(String modelN
ChatLanguageModel model = QwenChatModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();;
.build();

Response<AiMessage> response = model.generate(multimodalChatMessagesWithImageData());
System.out.println(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ public static Stream<Arguments> nonMultimodalChatModelNameProvider() {
);
}

public static Stream<Arguments> functionCallChatModelNameProvider() {
return Stream.of(
Arguments.of(QwenModelName.QWEN_MAX)
);
}

public static Stream<Arguments> multimodalChatModelNameProvider() {
return Stream.of(
Arguments.of(QwenModelName.QWEN_VL_PLUS),
Expand Down

0 comments on commit d580506

Please sign in to comment.