forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integration with ChatGLM (langchain4j#360)
Support [ChatGLM](https://github.com/THUDM/ChatGLM-6B), which was mentioned in langchain4j#267. The [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) and [ChatGLM3](https://github.com/THUDM/ChatGLM3) APIs are compatible with OpenAI, so it is enough to support ChatGLM. Because ChatGLM does not have an official Docker image, I don't know how to use `Testcontainers` to test it. (I'm not familiar with `Testcontainers`, so for now I have to copy the tests from the other modules, lol.) The tests will be updated to use `Testcontainers` after I learn about it in a few days.
- Loading branch information
Showing
10 changed files
with
339 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module descriptor for the LangChain4j ChatGLM integration. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- No dependency below declares a version; presumably they are managed by this parent POM (verify there). -->
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-parent</artifactId>
        <version>0.25.0-SNAPSHOT</version>
        <relativePath>../langchain4j-parent/pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-chatglm</artifactId>
    <packaging>jar</packaging>

    <name>LangChain4j integration with ChatGLM</name>

    <dependencies>
        <!-- LangChain4j core types (ChatLanguageModel, ChatMessage, Response, ...) used by the model class. -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
        </dependency>

        <!-- HTTP stack used by ChatGlmClient: Retrofit with the Gson converter on top of OkHttp. -->
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>retrofit</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>converter-gson</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
        </dependency>

        <!-- Lombok is needed at compile time only, hence the provided scope. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <scope>provided</scope>
        </dependency>

        <!-- Test-only dependencies. -->
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-params</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- NOTE(review): the integration test in this module does not reference Testcontainers yet; confirm this is still needed. -->
        <dependency>
            <groupId>org.testcontainers</groupId>
            <artifactId>testcontainers</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
21 changes: 21 additions & 0 deletions
21
langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.List; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
class ChatCompletionRequest { | ||
|
||
private String prompt; | ||
private Double temperature; | ||
private Double topP; | ||
private Integer maxLength; | ||
private List<List<String>> history; | ||
} |
20 changes: 20 additions & 0 deletions
20
langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.List; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
class ChatCompletionResponse { | ||
|
||
private String response; | ||
private List<List<String>> history; | ||
private Integer status; | ||
private String time; | ||
} |
15 changes: 15 additions & 0 deletions
15
langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import retrofit2.Call; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.Headers; | ||
import retrofit2.http.POST; | ||
|
||
interface ChatGlmApi { | ||
|
||
int OK = 200; | ||
|
||
@POST("/") | ||
@Headers({"Content-Type: application/json"}) | ||
Call<ChatCompletionResponse> chatCompletion(@Body ChatCompletionRequest chatCompletionRequest); | ||
} |
84 changes: 84 additions & 0 deletions
84
langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
|
||
/** | ||
* Support <a href="https://github.com/THUDM/ChatGLM-6B">ChatGLM</a>, | ||
* ChatGLM2 and ChatGLM3 api are compatible with OpenAI API | ||
*/ | ||
public class ChatGlmChatModel implements ChatLanguageModel { | ||
|
||
private final ChatGlmClient client; | ||
private final Double temperature; | ||
private final Double topP; | ||
private final Integer maxLength; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public ChatGlmChatModel(String baseUrl, Duration timeout, | ||
Double temperature, Integer maxRetries, | ||
Double topP, Integer maxLength) { | ||
this.client = new ChatGlmClient(baseUrl, timeout); | ||
this.temperature = getOrDefault(temperature, 0.7); | ||
this.maxRetries = getOrDefault(maxRetries, 3); | ||
this.topP = topP; | ||
this.maxLength = maxLength; | ||
} | ||
|
||
|
||
@Override | ||
public Response<AiMessage> generate(List<ChatMessage> messages) { | ||
// get last user message | ||
String prompt = messages.get(messages.size() - 1).text(); | ||
List<List<String>> history = toHistory(messages.subList(0, messages.size() - 1)); | ||
ChatCompletionRequest request = ChatCompletionRequest.builder() | ||
.prompt(prompt) | ||
.temperature(temperature) | ||
.topP(topP) | ||
.maxLength(maxLength) | ||
.history(history) | ||
.build(); | ||
|
||
ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries); | ||
|
||
return Response.from(AiMessage.from(response.getResponse())); | ||
} | ||
|
||
private List<List<String>> toHistory(List<ChatMessage> historyMessages) { | ||
// Order: User - AI - User - AI ... | ||
// so the length of historyMessages must be divisible by 2 | ||
if (containsSystemMessage(historyMessages)) { | ||
throw new IllegalArgumentException("ChatGLM does not support system prompt"); | ||
} | ||
|
||
if (historyMessages.size() % 2 != 0) { | ||
throw new IllegalArgumentException("History must be divisible by 2 because it's order User - AI - User - AI ..."); | ||
} | ||
|
||
List<List<String>> history = new ArrayList<>(); | ||
for (int i = 0; i < historyMessages.size() / 2; i++) { | ||
history.add(historyMessages.subList(i * 2, i * 2 + 2).stream() | ||
.map(ChatMessage::text) | ||
.collect(Collectors.toList())); | ||
} | ||
|
||
return history; | ||
} | ||
|
||
private boolean containsSystemMessage(List<ChatMessage> messages) { | ||
return messages.stream().anyMatch(message -> message.type() == ChatMessageType.SYSTEM); | ||
} | ||
} |
70 changes: 70 additions & 0 deletions
70
langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.GsonBuilder; | ||
import lombok.Builder; | ||
import okhttp3.OkHttpClient; | ||
import retrofit2.Response; | ||
import retrofit2.Retrofit; | ||
import retrofit2.converter.gson.GsonConverterFactory; | ||
|
||
import java.io.IOException; | ||
import java.time.Duration; | ||
|
||
import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; | ||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static java.time.Duration.ofSeconds; | ||
|
||
class ChatGlmClient { | ||
|
||
private final ChatGlmApi chatGLMApi; | ||
private static final Gson GSON = new GsonBuilder() | ||
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) | ||
.create(); | ||
|
||
|
||
@Builder | ||
public ChatGlmClient(String baseUrl, Duration timeout) { | ||
timeout = getOrDefault(timeout, ofSeconds(60)); | ||
|
||
OkHttpClient okHttpClient = new OkHttpClient.Builder() | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.build(); | ||
|
||
Retrofit retrofit = new Retrofit.Builder() | ||
.baseUrl(baseUrl) | ||
.client(okHttpClient) | ||
.addConverterFactory(GsonConverterFactory.create(GSON)) | ||
.build(); | ||
|
||
chatGLMApi = retrofit.create(ChatGlmApi.class); | ||
} | ||
|
||
public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { | ||
try { | ||
Response<ChatCompletionResponse> retrofitResponse | ||
= chatGLMApi.chatCompletion(request).execute(); | ||
|
||
if (retrofitResponse.isSuccessful() && retrofitResponse.body() != null | ||
&& retrofitResponse.body().getStatus() == ChatGlmApi.OK) { | ||
return retrofitResponse.body(); | ||
} else { | ||
throw toException(retrofitResponse); | ||
} | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private RuntimeException toException(Response<?> response) throws IOException { | ||
int code = response.code(); | ||
String body = response.errorBody().string(); | ||
|
||
String errorMessage = String.format("status code: %s; body: %s", code, body); | ||
return new RuntimeException(errorMessage); | ||
} | ||
|
||
} |
51 changes: 51 additions & 0 deletions
51
langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package dev.langchain4j.model.chatglm; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.data.message.UserMessage.userMessage; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
@Disabled("need local deployment of ChatGLM, see https://github.com/THUDM/ChatGLM-6B") | ||
class ChatGlmChatModelIT { | ||
|
||
ChatLanguageModel model = ChatGlmChatModel.builder() | ||
.baseUrl("http://localhost:8000") | ||
.build(); | ||
|
||
@Test | ||
void should_generate_answer() { | ||
UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?"); | ||
Response<AiMessage> response = model.generate(userMessage); | ||
assertThat(response.content().text()).contains("柏林"); | ||
} | ||
|
||
@Test | ||
void should_generate_answer_from_history() { | ||
// init history | ||
List<ChatMessage> messages = new ArrayList<>(); | ||
|
||
// given question first time | ||
UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?"); | ||
Response<AiMessage> response = model.generate(userMessage); | ||
assertThat(response.content().text()).contains("柏林"); | ||
|
||
// given question with history | ||
messages.add(userMessage); | ||
messages.add(response.content()); | ||
|
||
UserMessage secondUserMessage = userMessage("你能告诉我上个问题我问了你什么呢?请把我的问题原封不动的告诉我"); | ||
messages.add(secondUserMessage); | ||
|
||
Response<AiMessage> secondResponse = model.generate(messages); | ||
assertThat(secondResponse.content().text()).contains("德国"); // the answer should contain Germany in the First Question | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -497,4 +497,4 @@ | |
</build> | ||
</profile> | ||
</profiles> | ||
</project> | ||
</project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters