Skip to content

Commit

Permalink
Refactor the implementation for Qwen series models using the new Dash…
Browse files Browse the repository at this point in the history
…Scope SDK APIs. (langchain4j#155)

The design of the Dashscope SDK is evolving towards OpenAI, offering new
fields and specifications. Utilize these latest features to refactor the
implementation of the Qwen models.
  • Loading branch information
jiangsier-xyz authored Sep 3, 2023
1 parent 3bffc97 commit f2bb6f9
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 78 deletions.
2 changes: 1 addition & 1 deletion langchain4j-dashscope/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.2.0</version>
<version>2.3.1</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package dev.langchain4j.model.dashscope;

import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.agent.tool.ToolSpecification;
Expand All @@ -13,21 +13,20 @@
import dev.langchain4j.model.chat.ChatLanguageModel;

import java.util.List;
import java.util.Optional;

public class QwenChatModel implements ChatLanguageModel {
protected final Generation gen;
protected final String apiKey;
protected final String modelName;
protected final Double topP;
protected final Double topK;
protected final Integer topK;
protected final Boolean enableSearch;
protected final Integer seed;

protected QwenChatModel(String apiKey,
String modelName,
Double topP,
Double topK,
Integer topK,
Boolean enableSearch,
Integer seed) {

Expand All @@ -42,18 +41,14 @@ protected QwenChatModel(String apiKey,

@Override
public AiMessage sendMessages(List<ChatMessage> messages) {
return AiMessage.aiMessage(sendMessage(QwenParamHelper.toQwenPrompt(messages)));
return AiMessage.aiMessage(sendMessage(null, QwenHelper.toQwenMessages(messages)));
}

protected String sendMessage(String prompt) {
GenerationResult result = doSendMessage(prompt);
return Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getText)
.orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]");
protected String sendMessage(String prompt, List<Message> messages) {
return QwenHelper.answerFrom(doSendMessage(prompt, messages));
}

protected GenerationResult doSendMessage(String prompt) {
protected GenerationResult doSendMessage(String prompt, List<Message> messages) {
QwenParam param = QwenParam.builder()
.apiKey(apiKey)
.model(modelName)
Expand All @@ -62,6 +57,8 @@ protected GenerationResult doSendMessage(String prompt) {
.enableSearch(enableSearch)
.seed(seed)
.prompt(prompt)
.messages(messages)
.resultFormat(QwenParam.ResultFormat.MESSAGE)
.build();

try {
Expand All @@ -88,7 +85,7 @@ public static class Builder {
protected String apiKey;
protected String modelName;
protected Double topP;
protected Double topK;
protected Integer topK;
protected Boolean enableSearch;
protected Integer seed;

Expand All @@ -107,7 +104,7 @@ public Builder topP(Double topP) {
return this;
}

public Builder topK(Double topK) {
public Builder topK(Integer topK) {
this.topK = topK;
return this;
}
Expand All @@ -126,11 +123,8 @@ protected void ensureOptions() {
if (Utils.isNullOrBlank(apiKey)) {
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
}
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_V1 : modelName;
topP = topP == null ? 0.8 : topP;
topK = topK == null ? 100.0 : topK;
enableSearch = enableSearch == null ? Boolean.FALSE : enableSearch;
seed = seed == null ? 1234 : seed;
modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.QWEN_PLUS_V1 : modelName;
enableSearch = enableSearch != null && enableSearch;
}

public QwenChatModel build() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package dev.langchain4j.model.dashscope;

import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationOutput.Choice;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

public class QwenHelper {
public static List<Message> toQwenMessages(List<ChatMessage> messages) {
return messages.stream()
.map(QwenHelper::toQwenMessage)
.collect(Collectors.toList());
}

public static Message toQwenMessage(ChatMessage message) {
return Message.builder()
.role(roleFrom(message))
.content(message.text())
.build();
}

public static String roleFrom(ChatMessage message) {
if (message instanceof AiMessage) {
return Role.ASSISTANT.getValue();
} else if (message instanceof SystemMessage) {
return Role.SYSTEM.getValue();
} else {
return Role.USER.getValue();
}
}

public static String answerFrom(GenerationResult result) {
return Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getChoices)
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(Choice::getMessage)
.map(Message::getContent)
// Compatible with some older models.
.orElseGet(() -> Optional.of(result)
.map(GenerationResult::getOutput)
.map(GenerationOutput::getText)
.orElse("Oops, something wrong...[request id: " + result.getRequestId() + "]"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ public class QwenLanguageModel extends QwenChatModel implements LanguageModel {
protected QwenLanguageModel(String apiKey,
String modelName,
Double topP,
Double topK,
Integer topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}
@Override
public String process(String text) {
return sendMessage(text);
return sendMessage(text, null);
}

public static Builder builder() {
Expand All @@ -33,7 +33,7 @@ public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}

public Builder topK(Double topK) {
public Builder topK(Integer topK) {
return (Builder) super.topK(topK);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
Expand All @@ -16,33 +17,35 @@ public class QwenStreamingChatModel extends QwenChatModel implements StreamingCh
protected QwenStreamingChatModel(String apiKey,
String modelName,
Double topP,
Double topK,
Integer topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}

@Override
public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) {
sendMessage(QwenParamHelper.toQwenPrompt(messages), handler);
sendMessage(null, QwenHelper.toQwenMessages(messages), handler);
}

protected void sendMessage(String prompt, StreamingResponseHandler handler) {
protected void sendMessage(String rawPrompt, List<Message> messages, StreamingResponseHandler handler) {
QwenParam param = QwenParam.builder()
.apiKey(apiKey)
.model(modelName)
.topP(topP)
.topK(topK)
.enableSearch(enableSearch)
.seed(seed)
.prompt(prompt)
.prompt(rawPrompt)
.messages(messages)
.resultFormat(QwenParam.ResultFormat.MESSAGE)
.build();

try {
gen.call(param, new ResultCallback<GenerationResult>() {
gen.streamCall(param, new ResultCallback<GenerationResult>() {
@Override
public void onEvent(GenerationResult result) {
handler.onNext(result.getOutput().getText());
handler.onNext(QwenHelper.answerFrom(result));
}
@Override
public void onComplete() {
Expand Down Expand Up @@ -89,7 +92,7 @@ public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}

public Builder topK(Double topK) {
public Builder topK(Integer topK) {
return (Builder) super.topK(topK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ public class QwenStreamingLanguageModel extends QwenStreamingChatModel implement
protected QwenStreamingLanguageModel(String apiKey,
String modelName,
Double topP,
Double topK,
Integer topK,
Boolean enableSearch,
Integer seed) {
super(apiKey, modelName, topP, topK, enableSearch, seed);
}

@Override
public void process(String text, StreamingResponseHandler handler) {
sendMessage(text, handler);
sendMessage(text, null, handler);
}

public static Builder builder() {
Expand All @@ -35,7 +35,7 @@ public Builder topP(Double topP) {
return (Builder) super.topP(topP);
}

public Builder topK(Double topK) {
public Builder topK(Integer topK) {
return (Builder) super.topK(topK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public void should_send_messages_and_receive_response(String modelName) {
.modelName(modelName)
.build();
AiMessage answer = model.sendMessages(QwenTestHelper.chatMessages());
System.out.println(answer.text());
assertThat(answer.text()).containsIgnoringCase("rain");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public void should_send_messages_and_receive_response(String modelName) {
.apiKey(apiKey)
.modelName(modelName)
.build();
assertThat(model.process("Please say 'hello' to me")).containsIgnoringCase("hello");
String answer = model.process("Please say 'hello' to me");
System.out.println(answer);
assertThat(answer).containsIgnoringCase("hello");
}
}

0 comments on commit f2bb6f9

Please sign in to comment.