Skip to content

Commit

Permalink
Integration with ChatGLM (langchain4j#360)
Browse files Browse the repository at this point in the history
Support [ChatGLM](https://github.com/THUDM/ChatGLM-6B), which was
mentioned in langchain4j#267.

The [chatglm2](https://github.com/THUDM/ChatGLM2-6B) and
[chatglm3](https://github.com/THUDM/ChatGLM3) APIs are compatible with
OpenAI, so it is enough to support the original ChatGLM here.

Because ChatGLM does not have an official Docker image, I don't know how
to use `Testcontainers` for the tests. (I'm not familiar with
`Testcontainers`, so for now I have copied the tests from the other
modules.) The tests will be updated to use `Testcontainers` once I have
learned about it in a few days.
  • Loading branch information
Martin7-1 authored Dec 19, 2023
1 parent 968bb71 commit 17edd7b
Show file tree
Hide file tree
Showing 10 changed files with 339 additions and 1 deletion.
6 changes: 6 additions & 0 deletions langchain4j-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chatglm</artifactId>
<version>${project.version}</version>
</dependency>

<!-- embedding stores -->

<dependency>
Expand Down
70 changes: 70 additions & 0 deletions langchain4j-chatglm/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-parent</artifactId>
        <version>0.25.0-SNAPSHOT</version>
        <relativePath>../langchain4j-parent/pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-chatglm</artifactId>
    <packaging>jar</packaging>

    <name>LangChain4j integration with ChatGLM</name>

    <dependencies>
        <!-- core abstractions (ChatLanguageModel, messages, Response) -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
        </dependency>

        <!-- HTTP stack: Retrofit with a Gson converter over OkHttp
             (versions are managed by langchain4j-parent) -->
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>retrofit</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>converter-gson</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
        </dependency>

        <!-- compile-time only: generates getters/builders for the DTOs -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <scope>provided</scope>
        </dependency>

        <!-- test dependencies -->
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-params</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.testcontainers</groupId>
            <artifactId>testcontainers</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package dev.langchain4j.model.chatglm;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

/**
 * Request body for the ChatGLM chat-completion endpoint ({@code POST /}).
 * Fields are serialized to lower_snake_case JSON (e.g. {@code topP} -> {@code top_p})
 * by the Gson instance configured in {@code ChatGlmClient}.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
class ChatCompletionRequest {

    private String prompt;              // latest user message to answer
    private Double temperature;         // sampling temperature
    private Double topP;                // nucleus-sampling parameter
    private Integer maxLength;          // maximum length of the generation
    private List<List<String>> history; // prior conversation as [user, ai] pairs
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package dev.langchain4j.model.chatglm;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

/**
 * Response body returned by the ChatGLM chat-completion endpoint.
 * Fields are deserialized from lower_snake_case JSON by the Gson instance
 * configured in {@code ChatGlmClient}.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
class ChatCompletionResponse {

    private String response;            // generated answer text
    private List<List<String>> history; // conversation including the new turn
    private Integer status;             // application-level status; ChatGlmApi.OK on success
    private String time;                // server-side timestamp (string form)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.chatglm;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Headers;
import retrofit2.http.POST;

/**
 * Retrofit definition of the ChatGLM HTTP API: a single JSON POST to the
 * server root that performs one chat completion.
 */
interface ChatGlmApi {

    // application-level success value carried in ChatCompletionResponse.status
    // (distinct from the HTTP status code, though it uses the same number)
    int OK = 200;

    @POST("/")
    @Headers({"Content-Type: application/json"})
    Call<ChatCompletionResponse> chatCompletion(@Body ChatCompletionRequest chatCompletionRequest);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package dev.langchain4j.model.chatglm;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;

/**
 * Supports <a href="https://github.com/THUDM/ChatGLM-6B">ChatGLM</a>.
 * ChatGLM2 and ChatGLM3 APIs are compatible with the OpenAI API and are
 * therefore covered by the OpenAI integration instead.
 */
public class ChatGlmChatModel implements ChatLanguageModel {

    private final ChatGlmClient client;
    private final Double temperature;
    private final Double topP;
    private final Integer maxLength;
    private final Integer maxRetries;

    /**
     * @param baseUrl     base URL of a locally deployed ChatGLM server (required)
     * @param timeout     HTTP timeout; defaults to 60 seconds (see ChatGlmClient)
     * @param temperature sampling temperature; defaults to 0.7
     * @param maxRetries  number of retry attempts on failure; defaults to 3
     * @param topP        nucleus-sampling parameter; null means server default
     * @param maxLength   maximum generation length; null means server default
     */
    @Builder
    public ChatGlmChatModel(String baseUrl, Duration timeout,
                            Double temperature, Integer maxRetries,
                            Double topP, Integer maxLength) {
        this.client = new ChatGlmClient(baseUrl, timeout);
        this.temperature = getOrDefault(temperature, 0.7);
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.topP = topP;
        this.maxLength = maxLength;
    }

    /**
     * Sends the last message as the prompt and all preceding messages as history.
     *
     * @param messages conversation ordered User - AI - User - AI ... ending with
     *                 the message to answer; must not be null or empty
     * @throws IllegalArgumentException if messages is null/empty, contains a
     *                                  system message, or the history cannot be
     *                                  split into [user, ai] pairs
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        if (messages == null || messages.isEmpty()) {
            // fail fast with a clear message instead of an IndexOutOfBoundsException
            throw new IllegalArgumentException("messages must not be null or empty");
        }

        // the last message is the prompt; everything before it becomes history
        String prompt = messages.get(messages.size() - 1).text();
        List<List<String>> history = toHistory(messages.subList(0, messages.size() - 1));
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .prompt(prompt)
                .temperature(temperature)
                .topP(topP)
                .maxLength(maxLength)
                .history(history)
                .build();

        ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries);

        return Response.from(AiMessage.from(response.getResponse()));
    }

    /**
     * Converts prior messages into ChatGLM's history format: a list of
     * [user text, ai text] pairs.
     */
    private List<List<String>> toHistory(List<ChatMessage> historyMessages) {
        // Order: User - AI - User - AI ...
        // so the length of historyMessages must be divisible by 2
        if (containsSystemMessage(historyMessages)) {
            throw new IllegalArgumentException("ChatGLM does not support system prompt");
        }

        if (historyMessages.size() % 2 != 0) {
            throw new IllegalArgumentException("History must be divisible by 2 because it's order User - AI - User - AI ...");
        }

        List<List<String>> history = new ArrayList<>();
        for (int i = 0; i < historyMessages.size() / 2; i++) {
            // NOTE(review): only the size is validated; the actual User/AI
            // alternation of each pair is assumed — confirm callers respect it
            history.add(historyMessages.subList(i * 2, i * 2 + 2).stream()
                    .map(ChatMessage::text)
                    .collect(Collectors.toList()));
        }

        return history;
    }

    private boolean containsSystemMessage(List<ChatMessage> messages) {
        return messages.stream().anyMatch(message -> message.type() == ChatMessageType.SYSTEM);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package dev.langchain4j.model.chatglm;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.time.Duration.ofSeconds;

/**
 * Synchronous HTTP client for a ChatGLM server, built on Retrofit/OkHttp
 * with Gson for JSON (de)serialization.
 */
class ChatGlmClient {

    // ChatGLM uses lower_snake_case JSON fields (e.g. top_p, max_length)
    private static final Gson GSON = new GsonBuilder()
            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
            .create();

    private final ChatGlmApi chatGlmApi;

    /**
     * @param baseUrl base URL of the ChatGLM server (required by Retrofit)
     * @param timeout applied to call/connect/read/write; defaults to 60 seconds
     */
    @Builder
    public ChatGlmClient(String baseUrl, Duration timeout) {
        timeout = getOrDefault(timeout, ofSeconds(60));

        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .build();

        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl(baseUrl)
                .client(okHttpClient)
                .addConverterFactory(GsonConverterFactory.create(GSON))
                .build();

        chatGlmApi = retrofit.create(ChatGlmApi.class);
    }

    /**
     * Executes one chat completion. Succeeds only when both the HTTP call and
     * the application-level status in the body indicate success.
     *
     * @throws RuntimeException on transport failure (wrapping the IOException)
     *                          or on any non-OK response
     */
    public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) {
        try {
            Response<ChatCompletionResponse> retrofitResponse
                    = chatGlmApi.chatCompletion(request).execute();

            ChatCompletionResponse body = retrofitResponse.body();
            // getStatus() is a nullable Integer: compare with a null check to
            // avoid an NPE from auto-unboxing on a malformed response
            if (retrofitResponse.isSuccessful() && body != null
                    && body.getStatus() != null && body.getStatus() == ChatGlmApi.OK) {
                return body;
            } else {
                throw toException(retrofitResponse);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private RuntimeException toException(Response<?> response) throws IOException {
        int code = response.code();
        // errorBody() is null when the HTTP call itself succeeded (2xx) but the
        // application-level status check failed — fall back to the success body
        // instead of throwing an NPE while building the error message
        String body = response.errorBody() != null
                ? response.errorBody().string()
                : String.valueOf(response.body());

        String errorMessage = String.format("status code: %s; body: %s", code, body);
        return new RuntimeException(errorMessage);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package dev.langchain4j.model.chatglm;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static org.assertj.core.api.Assertions.assertThat;

@Disabled("need local deployment of ChatGLM, see https://github.com/THUDM/ChatGLM-6B")
class ChatGlmChatModelIT {

    ChatLanguageModel model = ChatGlmChatModel.builder()
            .baseUrl("http://localhost:8000")
            .build();

    @Test
    void should_generate_answer() {
        // single-turn question with no history
        UserMessage question = userMessage("你好,请问一下德国的首都是哪里呢?");

        Response<AiMessage> answer = model.generate(question);

        assertThat(answer.content().text()).contains("柏林");
    }

    @Test
    void should_generate_answer_from_history() {
        // first round: ask without any history
        UserMessage firstQuestion = userMessage("你好,请问一下德国的首都是哪里呢?");
        Response<AiMessage> firstAnswer = model.generate(firstQuestion);
        assertThat(firstAnswer.content().text()).contains("柏林");

        // second round: replay the conversation so far, then add a follow-up
        List<ChatMessage> conversation = new ArrayList<>();
        conversation.add(firstQuestion);
        conversation.add(firstAnswer.content());
        conversation.add(userMessage("你能告诉我上个问题我问了你什么呢?请把我的问题原封不动的告诉我"));

        Response<AiMessage> secondAnswer = model.generate(conversation);
        assertThat(secondAnswer.content().text()).contains("德国"); // the answer should contain Germany in the First Question
    }
}
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -497,4 +497,4 @@
</build>
</profile>
</profiles>
</project>
</project>
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<module>langchain4j-open-ai</module>
<module>langchain4j-vertex-ai</module>
<module>langchain4j-ollama</module>
<module>langchain4j-chatglm</module>

<!-- embedding stores -->
<module>langchain4j-cassandra</module>
Expand Down

0 comments on commit 17edd7b

Please sign in to comment.