From 17edd7b50ccfcfd1b244f2e72c772164b719596a Mon Sep 17 00:00:00 2001 From: ZYinNJU <1754350460@qq.com> Date: Tue, 19 Dec 2023 22:09:50 +0800 Subject: [PATCH] Integration with ChatGLM (#360) Support [chatglm](https://github.com/THUDM/ChatGLM-6B), which was mentioned in #267 . The [chatglm2](https://github.com/THUDM/ChatGLM2-6B) and [chatglm3](https://github.com/THUDM/ChatGLM3) APIs are compatible with OpenAI, so it is enough to support chatglm. Because ChatGLM does not have an official Docker image, I don't know how to use `Testcontainers` for testing. (I'm not familiar with `Testcontainers`, so for now I have copied the tests from the other modules.) The tests will be updated to use `Testcontainers` after I learn about it in a few days. --- langchain4j-bom/pom.xml | 6 ++ langchain4j-chatglm/pom.xml | 70 ++++++++++++++++ .../model/chatglm/ChatCompletionRequest.java | 21 +++++ .../model/chatglm/ChatCompletionResponse.java | 20 +++++ .../langchain4j/model/chatglm/ChatGlmApi.java | 15 ++++ .../model/chatglm/ChatGlmChatModel.java | 84 +++++++++++++++++++ .../model/chatglm/ChatGlmClient.java | 70 ++++++++++++++++ .../model/chatglm/ChatGlmChatModelIT.java | 51 +++++++++++ langchain4j-parent/pom.xml | 2 +- pom.xml | 1 + 10 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 langchain4j-chatglm/pom.xml create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java create mode 100644 langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java create mode 100644 langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java diff --git a/langchain4j-bom/pom.xml 
b/langchain4j-bom/pom.xml index 8df0b78ffe6..d5ee718643d 100644 --- a/langchain4j-bom/pom.xml +++ b/langchain4j-bom/pom.xml @@ -87,6 +87,12 @@ ${project.version} + + dev.langchain4j + langchain4j-chatglm + ${project.version} + + diff --git a/langchain4j-chatglm/pom.xml b/langchain4j-chatglm/pom.xml new file mode 100644 index 00000000000..cd1e1ed03e5 --- /dev/null +++ b/langchain4j-chatglm/pom.xml @@ -0,0 +1,70 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.25.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-chatglm + jar + + LangChain4j integration with ChatGLM + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-gson + + + + com.squareup.okhttp3 + okhttp + + + + org.projectlombok + lombok + provided + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + org.testcontainers + testcontainers + test + + + + \ No newline at end of file diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java new file mode 100644 index 00000000000..0fdece86301 --- /dev/null +++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionRequest.java @@ -0,0 +1,21 @@ +package dev.langchain4j.model.chatglm; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ChatCompletionRequest { + + private String prompt; + private Double temperature; + private Double topP; + private Integer maxLength; + private List> history; +} diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java 
b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java new file mode 100644 index 00000000000..531177525aa --- /dev/null +++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatCompletionResponse.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.chatglm; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ChatCompletionResponse { + + private String response; + private List> history; + private Integer status; + private String time; +} diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java new file mode 100644 index 00000000000..363f4727fea --- /dev/null +++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmApi.java @@ -0,0 +1,15 @@ +package dev.langchain4j.model.chatglm; + +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.Headers; +import retrofit2.http.POST; + +interface ChatGlmApi { + + int OK = 200; + + @POST("/") + @Headers({"Content-Type: application/json"}) + Call chatCompletion(@Body ChatCompletionRequest chatCompletionRequest); +} diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java new file mode 100644 index 00000000000..a6a3c9db21c --- /dev/null +++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmChatModel.java @@ -0,0 +1,84 @@ +package dev.langchain4j.model.chatglm; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; 
+import lombok.Builder; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; + +/** + * Support ChatGLM, + * ChatGLM2 and ChatGLM3 api are compatible with OpenAI API + */ +public class ChatGlmChatModel implements ChatLanguageModel { + + private final ChatGlmClient client; + private final Double temperature; + private final Double topP; + private final Integer maxLength; + private final Integer maxRetries; + + @Builder + public ChatGlmChatModel(String baseUrl, Duration timeout, + Double temperature, Integer maxRetries, + Double topP, Integer maxLength) { + this.client = new ChatGlmClient(baseUrl, timeout); + this.temperature = getOrDefault(temperature, 0.7); + this.maxRetries = getOrDefault(maxRetries, 3); + this.topP = topP; + this.maxLength = maxLength; + } + + + @Override + public Response generate(List messages) { + // get last user message + String prompt = messages.get(messages.size() - 1).text(); + List> history = toHistory(messages.subList(0, messages.size() - 1)); + ChatCompletionRequest request = ChatCompletionRequest.builder() + .prompt(prompt) + .temperature(temperature) + .topP(topP) + .maxLength(maxLength) + .history(history) + .build(); + + ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries); + + return Response.from(AiMessage.from(response.getResponse())); + } + + private List> toHistory(List historyMessages) { + // Order: User - AI - User - AI ... 
+ // so the length of historyMessages must be divisible by 2 + if (containsSystemMessage(historyMessages)) { + throw new IllegalArgumentException("ChatGLM does not support system prompt"); + } + + if (historyMessages.size() % 2 != 0) { + throw new IllegalArgumentException("History must be divisible by 2 because it's order User - AI - User - AI ..."); + } + + List> history = new ArrayList<>(); + for (int i = 0; i < historyMessages.size() / 2; i++) { + history.add(historyMessages.subList(i * 2, i * 2 + 2).stream() + .map(ChatMessage::text) + .collect(Collectors.toList())); + } + + return history; + } + + private boolean containsSystemMessage(List messages) { + return messages.stream().anyMatch(message -> message.type() == ChatMessageType.SYSTEM); + } +} diff --git a/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java new file mode 100644 index 00000000000..21e88a143b4 --- /dev/null +++ b/langchain4j-chatglm/src/main/java/dev/langchain4j/model/chatglm/ChatGlmClient.java @@ -0,0 +1,70 @@ +package dev.langchain4j.model.chatglm; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import lombok.Builder; +import okhttp3.OkHttpClient; +import retrofit2.Response; +import retrofit2.Retrofit; +import retrofit2.converter.gson.GsonConverterFactory; + +import java.io.IOException; +import java.time.Duration; + +import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static java.time.Duration.ofSeconds; + +class ChatGlmClient { + + private final ChatGlmApi chatGLMApi; + private static final Gson GSON = new GsonBuilder() + .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) + .create(); + + + @Builder + public ChatGlmClient(String baseUrl, Duration timeout) { + timeout = getOrDefault(timeout, ofSeconds(60)); + + OkHttpClient okHttpClient = new 
OkHttpClient.Builder() + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout) + .build(); + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(baseUrl) + .client(okHttpClient) + .addConverterFactory(GsonConverterFactory.create(GSON)) + .build(); + + chatGLMApi = retrofit.create(ChatGlmApi.class); + } + + public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { + try { + Response retrofitResponse + = chatGLMApi.chatCompletion(request).execute(); + + if (retrofitResponse.isSuccessful() && retrofitResponse.body() != null + && retrofitResponse.body().getStatus() == ChatGlmApi.OK) { + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private RuntimeException toException(Response response) throws IOException { + int code = response.code(); + String body = response.errorBody().string(); + + String errorMessage = String.format("status code: %s; body: %s", code, body); + return new RuntimeException(errorMessage); + } + +} diff --git a/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java b/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java new file mode 100644 index 00000000000..e7d9401900a --- /dev/null +++ b/langchain4j-chatglm/src/test/java/dev/langchain4j/model/chatglm/ChatGlmChatModelIT.java @@ -0,0 +1,51 @@ +package dev.langchain4j.model.chatglm; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static 
org.assertj.core.api.Assertions.assertThat; + +@Disabled("need local deployment of ChatGLM, see https://github.com/THUDM/ChatGLM-6B") +class ChatGlmChatModelIT { + + ChatLanguageModel model = ChatGlmChatModel.builder() + .baseUrl("http://localhost:8000") + .build(); + + @Test + void should_generate_answer() { + UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?"); + Response response = model.generate(userMessage); + assertThat(response.content().text()).contains("柏林"); + } + + @Test + void should_generate_answer_from_history() { + // init history + List messages = new ArrayList<>(); + + // given question first time + UserMessage userMessage = userMessage("你好,请问一下德国的首都是哪里呢?"); + Response response = model.generate(userMessage); + assertThat(response.content().text()).contains("柏林"); + + // given question with history + messages.add(userMessage); + messages.add(response.content()); + + UserMessage secondUserMessage = userMessage("你能告诉我上个问题我问了你什么呢?请把我的问题原封不动的告诉我"); + messages.add(secondUserMessage); + + Response secondResponse = model.generate(messages); + assertThat(secondResponse.content().text()).contains("德国"); // the answer should contain Germany in the First Question + } +} diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml index 01a6a803e4f..da6e4801815 100644 --- a/langchain4j-parent/pom.xml +++ b/langchain4j-parent/pom.xml @@ -497,4 +497,4 @@ - + \ No newline at end of file diff --git a/pom.xml b/pom.xml index 07901a56960..76df616d9ec 100644 --- a/pom.xml +++ b/pom.xml @@ -27,6 +27,7 @@ langchain4j-open-ai langchain4j-vertex-ai langchain4j-ollama + langchain4j-chatglm langchain4j-cassandra