From 17edd7b50ccfcfd1b244f2e72c772164b719596a Mon Sep 17 00:00:00 2001
From: ZYinNJU <1754350460@qq.com>
Date: Tue, 19 Dec 2023 22:09:50 +0800
Subject: [PATCH] Integration with ChatGLM (#360)
support [chatglm](https://github.com/THUDM/ChatGLM-6B), which was
mentioned in #267 .
[chatglm2](https://github.com/THUDM/ChatGLM2-6B) and
[chatglm3](https://github.com/THUDM/ChatGLM3) api are compatible with
openai, so it is enough to support chatglm.
Because chatglm does not have an official docker image, I don't know
how to use `Testcontainers` to test it. (I'm not familiar with
`Testcontainers`, so for now I have to copy tests from the other
modules, lol.) The test will be updated to use `Testcontainers` after
I learn about it in a few days.
---
langchain4j-bom/pom.xml | 6 ++
langchain4j-chatglm/pom.xml | 70 ++++++++++++++++
.../model/chatglm/ChatCompletionRequest.java | 21 +++++
.../model/chatglm/ChatCompletionResponse.java | 20 +++++
.../langchain4j/model/chatglm/ChatGlmApi.java | 15 ++++
.../model/chatglm/ChatGlmChatModel.java | 84 +++++++++++++++++++
.../model/chatglm/ChatGlmClient.java | 70 ++++++++++++++++
.../model/chatglm/ChatGlmChatModelIT.java | 51 +++++++++++
langchain4j-parent/pom.xml | 2 +-
pom.xml | 1 +
10 files changed, 339 insertions(+), 1 deletion(-)
create mode 100644 langchain4j-chatglm/pom.xml
create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java
create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java
create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java
create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java
create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java
create mode 100644 langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java
diff --git a/langchain4j-bom/pom.xml b/langchain4j-bom/pom.xml
index 8df0b78ffe6..d5ee718643d 100644
--- a/langchain4j-bom/pom.xml
+++ b/langchain4j-bom/pom.xml
@@ -87,6 +87,12 @@
${project.version}
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-chatglm</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
diff --git a/langchain4j-chatglm/pom.xml b/langchain4j-chatglm/pom.xml
new file mode 100644
index 00000000000..cd1e1ed03e5
--- /dev/null
+++ b/langchain4j-chatglm/pom.xml
@@ -0,0 +1,70 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>dev.langchain4j</groupId>
+        <artifactId>langchain4j-parent</artifactId>
+        <version>0.25.0-SNAPSHOT</version>
+        <relativePath>../langchain4j-parent/pom.xml</relativePath>
+    </parent>
+
+    <artifactId>langchain4j-chatglm</artifactId>
+    <packaging>jar</packaging>
+
+    <name>LangChain4j integration with ChatGLM</name>
+
+    <dependencies>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-core</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>com.squareup.retrofit2</groupId>
+            <artifactId>retrofit</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>com.squareup.retrofit2</groupId>
+            <artifactId>converter-gson</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>com.squareup.okhttp3</groupId>
+            <artifactId>okhttp</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-engine</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-params</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.assertj</groupId>
+            <artifactId>assertj-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>testcontainers</artifactId>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+</project>
\ No newline at end of file
diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java
new file mode 100644
index 00000000000..0fdece86301
--- /dev/null
+++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java
@@ -0,0 +1,21 @@
+package dev.langchain4j.model.chatglm;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+@Builder
+class ChatCompletionRequest {
+
+ private String prompt;
+ private Double temperature;
+ private Double topP;
+ private Integer maxLength;
+ private List<List<String>> history;
+}
diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java
new file mode 100644
index 00000000000..531177525aa
--- /dev/null
+++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.model.chatglm;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+@Builder
+class ChatCompletionResponse {
+
+ private String response;
+ private List<List<String>> history;
+ private Integer status;
+ private String time;
+}
diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java
new file mode 100644
index 00000000000..363f4727fea
--- /dev/null
+++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java
@@ -0,0 +1,15 @@
+package dev.langchain4j.model.chatglm;
+
+import retrofit2.Call;
+import retrofit2.http.Body;
+import retrofit2.http.Headers;
+import retrofit2.http.POST;
+
+interface ChatGlmApi {
+
+ int OK = 200;
+
+ @POST("/")
+ @Headers({"Content-Type: application/json"})
+ Call<ChatCompletionResponse> chatCompletion(@Body ChatCompletionRequest chatCompletionRequest);
+}
diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java
new file mode 100644
index 00000000000..a6a3c9db21c
--- /dev/null
+++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java
@@ -0,0 +1,84 @@
+package dev.langchain4j.model.chatglm;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ChatMessageType;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import lombok.Builder;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static dev.langchain4j.internal.RetryUtils.withRetry;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+
+/**
+ * Support ChatGLM,
+ * ChatGLM2 and ChatGLM3 api are compatible with OpenAI API
+ */
+public class ChatGlmChatModel implements ChatLanguageModel {
+
+ private final ChatGlmClient client;
+ private final Double temperature;
+ private final Double topP;
+ private final Integer maxLength;
+ private final Integer maxRetries;
+
+ @Builder
+ public ChatGlmChatModel(String baseUrl, Duration timeout,
+ Double temperature, Integer maxRetries,
+ Double topP, Integer maxLength) {
+ this.client = new ChatGlmClient(baseUrl, timeout);
+ this.temperature = getOrDefault(temperature, 0.7);
+ this.maxRetries = getOrDefault(maxRetries, 3);
+ this.topP = topP;
+ this.maxLength = maxLength;
+ }
+
+
+ @Override
+ public Response<AiMessage> generate(List<ChatMessage> messages) {
+ // get last user message
+ String prompt = messages.get(messages.size() - 1).text();
+ List<List<String>> history = toHistory(messages.subList(0, messages.size() - 1));
+ ChatCompletionRequest request = ChatCompletionRequest.builder()
+ .prompt(prompt)
+ .temperature(temperature)
+ .topP(topP)
+ .maxLength(maxLength)
+ .history(history)
+ .build();
+
+ ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries);
+
+ return Response.from(AiMessage.from(response.getResponse()));
+ }
+
+ private List<List<String>> toHistory(List<ChatMessage> historyMessages) {
+ // Order: User - AI - User - AI ...
+ // so the length of historyMessages must be divisible by 2
+ if (containsSystemMessage(historyMessages)) {
+ throw new IllegalArgumentException("ChatGLM does not support system prompt");
+ }
+
+ if (historyMessages.size() % 2 != 0) {
+ throw new IllegalArgumentException("History must be divisible by 2 because it's order User - AI - User - AI ...");
+ }
+
+ List<List<String>> history = new ArrayList<>();
+ for (int i = 0; i < historyMessages.size() / 2; i++) {
+ history.add(historyMessages.subList(i * 2, i * 2 + 2).stream()
+ .map(ChatMessage::text)
+ .collect(Collectors.toList()));
+ }
+
+ return history;
+ }
+
+ private boolean containsSystemMessage(List<ChatMessage> messages) {
+ return messages.stream().anyMatch(message -> message.type() == ChatMessageType.SYSTEM);
+ }
+}
diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java
new file mode 100644
index 00000000000..21e88a143b4
--- /dev/null
+++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java
@@ -0,0 +1,70 @@
+package dev.langchain4j.model.chatglm;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import lombok.Builder;
+import okhttp3.OkHttpClient;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.gson.GsonConverterFactory;
+
+import java.io.IOException;
+import java.time.Duration;
+
+import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static java.time.Duration.ofSeconds;
+
+class ChatGlmClient {
+
+ private final ChatGlmApi chatGLMApi;
+ private static final Gson GSON = new GsonBuilder()
+ .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
+ .create();
+
+
+ @Builder
+ public ChatGlmClient(String baseUrl, Duration timeout) {
+ timeout = getOrDefault(timeout, ofSeconds(60));
+
+ OkHttpClient okHttpClient = new OkHttpClient.Builder()
+ .callTimeout(timeout)
+ .connectTimeout(timeout)
+ .readTimeout(timeout)
+ .writeTimeout(timeout)
+ .build();
+
+ Retrofit retrofit = new Retrofit.Builder()
+ .baseUrl(baseUrl)
+ .client(okHttpClient)
+ .addConverterFactory(GsonConverterFactory.create(GSON))
+ .build();
+
+ chatGLMApi = retrofit.create(ChatGlmApi.class);
+ }
+
+ public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) {
+ try {
+ Response<ChatCompletionResponse> retrofitResponse
+ = chatGLMApi.chatCompletion(request).execute();
+
+ if (retrofitResponse.isSuccessful() && retrofitResponse.body() != null
+ && retrofitResponse.body().getStatus() == ChatGlmApi.OK) {
+ return retrofitResponse.body();
+ } else {
+ throw toException(retrofitResponse);
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private RuntimeException toException(Response<?> response) throws IOException {
+ int code = response.code();
+ String body = response.errorBody().string();
+
+ String errorMessage = String.format("status code: %s; body: %s", code, body);
+ return new RuntimeException(errorMessage);
+ }
+
+}
diff --git a/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java b/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java
new file mode 100644
index 00000000000..e7d9401900a
--- /dev/null
+++ b/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java
@@ -0,0 +1,51 @@
+package dev.langchain4j.model.chatglm;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static dev.langchain4j.data.message.UserMessage.userMessage;
+import static org.assertj.core.api.Assertions.assertThat;
+
+@Disabled("need local deployment of ChatGLM, see https://github.com/THUDM/ChatGLM-6B")
+class ChatGlmChatModelIT {
+
+ ChatLanguageModel model = ChatGlmChatModel.builder()
+ .baseUrl("http://localhost:8000")
+ .build();
+
+ @Test
+ void should_generate_answer() {
+ UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?");
+ Response<AiMessage> response = model.generate(userMessage);
+ assertThat(response.content().text()).contains("柏林");
+ }
+
+ @Test
+ void should_generate_answer_from_history() {
+ // init history
+ List<ChatMessage> messages = new ArrayList<>();
+
+ // given question first time
+ UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?");
+ Response<AiMessage> response = model.generate(userMessage);
+ assertThat(response.content().text()).contains("柏林");
+
+ // given question with history
+ messages.add(userMessage);
+ messages.add(response.content());
+
+ UserMessage secondUserMessage = userMessage("你能告诉我上个问题我问了你什么呢?请把我的问题原封不动的告诉我");
+ messages.add(secondUserMessage);
+
+ Response<AiMessage> secondResponse = model.generate(messages);
+ assertThat(secondResponse.content().text()).contains("德国"); // the answer should contain Germany in the First Question
+ }
+}
diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml
index 01a6a803e4f..da6e4801815 100644
--- a/langchain4j-parent/pom.xml
+++ b/langchain4j-parent/pom.xml
@@ -497,4 +497,4 @@
-</project>
+</project>
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 07901a56960..76df616d9ec 100644
--- a/pom.xml
+++ b/pom.xml
@@ -27,6 +27,7 @@
langchain4j-open-ai
langchain4j-vertex-ai
langchain4j-ollama
+ langchain4j-chatglm
langchain4j-cassandra