Skip to content

Commit

Permalink
add glm-4v model (langchain4j#1469)
Browse files Browse the repository at this point in the history
## Issue
<!-- Please specify the ID of the issue this PR is addressing. For example: "Closes #1234" or "Fixes #1234" -->
Closes langchain4j#1468 

## Change
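Add the `glm-4v` model to the Zhipu AI integration, so that user messages containing image content can be sent through both the chat model and the streaming chat model.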

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)

## Checklist for adding new model integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
  • Loading branch information
1402564807 authored Jul 16, 2024
1 parent 9dd431d commit b971adf
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/docs/integrations/language-models/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ sidebar_position: 0
| [OpenAI](/integrations/language-models/open-ai) |||| Compatible with: Groq, Ollama, LM Studio, GPT4All, etc. ||
| [Qianfan](/integrations/language-models/qianfan) ||| | | |
| [Cloudflare Workers AI](/integrations/language-models/workers-ai) | | | | | |
| [Zhipu AI](/integrations/language-models/zhipu-ai) ||| | | |
| [Zhipu AI](/integrations/language-models/zhipu-ai) ||| | | |
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.zhipu.chat.Content;
import dev.langchain4j.model.zhipu.chat.*;
import dev.langchain4j.model.zhipu.embedding.EmbeddingResponse;
import dev.langchain4j.model.zhipu.shared.ErrorResponse;
Expand Down Expand Up @@ -82,9 +86,27 @@ private static Message toZhipuAiMessage(ChatMessage message) {

if (message instanceof UserMessage) {
UserMessage userMessage = (UserMessage) message;
return dev.langchain4j.model.zhipu.chat.UserMessage.builder()
.content(userMessage.singleText())
.build();
if (userMessage.hasSingleText()) {
return dev.langchain4j.model.zhipu.chat.UserMessage.from(userMessage.singleText());
}
List<Content> contents = new ArrayList<>(userMessage.contents().size());
userMessage.contents().forEach(content -> {
if (content instanceof TextContent) {
TextContent textContent = (TextContent) content;
contents.add(dev.langchain4j.model.zhipu.chat.TextContent.builder()
.text(textContent.text())
.build());
}
if (content instanceof ImageContent) {
Image image = ((ImageContent) content).image();
contents.add(dev.langchain4j.model.zhipu.chat.ImageContent.builder()
.imageUrl(dev.langchain4j.model.zhipu.chat.Image.builder()
.url(image.url() != null ? image.url().toString() : image.base64Data())
.build())
.build());
}
});
return dev.langchain4j.model.zhipu.chat.UserMessage.from(contents);
}

if (message instanceof AiMessage) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.zhipu.chat.ChatCompletionModel;
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
import dev.langchain4j.model.zhipu.chat.ToolChoiceMode;
import dev.langchain4j.model.zhipu.spi.ZhipuAiStreamingChatModelBuilderFactory;
Expand All @@ -22,6 +23,7 @@

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.*;
import static dev.langchain4j.model.zhipu.chat.ChatCompletionModel.GLM_4;
Expand Down Expand Up @@ -126,5 +128,16 @@ public static class ZhipuAiStreamingChatModelBuilder {
/**
 * Creates a builder; no state is initialised here.
 */
public ZhipuAiStreamingChatModelBuilder() {
    // Intentionally empty.
}

/**
 * Selects the model through one of the {@link ChatCompletionModel} constants.
 * The builder stores the constant's {@code toString()} value.
 *
 * @param model the model constant; must not be {@code null}
 * @return this builder, for chaining
 * @throws NullPointerException if {@code model} is {@code null}
 */
public ZhipuAiStreamingChatModelBuilder model(ChatCompletionModel model) {
    // Fail with a named message rather than an anonymous NPE raised by model.toString().
    this.model = java.util.Objects.requireNonNull(model, "model").toString();
    return this;
}

/**
 * Selects the model by its name.
 *
 * @param model the model name, checked with {@code ensureNotBlank} before it is stored
 * @return this builder, for chaining
 */
public ZhipuAiStreamingChatModelBuilder model(String model) {
    // ensureNotBlank hands back the validated value, so validation and assignment share one statement.
    this.model = ensureNotBlank(model, "model");
    return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

public enum ChatCompletionModel {
GLM_4("glm-4"),
GLM_4V("glm-4v"),
GLM_4_0520("glm-4-0520"),
GLM_4_AIR("glm-4-air"),
GLM_4_AIRX("glm-4-airx"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package dev.langchain4j.model.zhipu.chat;

/**
 * Marker type for a single part of a multi-part message content.
 * It declares no methods; implementations carry the part-specific data.
 */
public interface Content {
    // no members
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dev.langchain4j.model.zhipu.chat;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Builder;
import lombok.Data;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Image reference carried inside an {@link ImageContent} part.
 * <p>
 * JSON mapping, per the annotations below: snake_case property names, {@code null}
 * fields omitted when writing, unknown properties ignored when reading.
 */
@Data
@Builder
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@JsonIgnoreProperties(ignoreUnknown = true)
public final class Image {
// Where the image comes from; the message helper in this change passes either a URL string or base64 data here.
private String url;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package dev.langchain4j.model.zhipu.chat;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Builder;
import lombok.Data;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Image part of a multi-part message content.
 * <p>
 * JSON mapping, per the annotations below: snake_case property names, {@code null}
 * fields omitted when writing, unknown properties ignored when reading.
 * <p>
 * NOTE(review): unlike {@code Image} and {@code TextContent}, this class is not declared
 * {@code final} - confirm whether that is intentional.
 */
@Data
@Builder
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@JsonIgnoreProperties(ignoreUnknown = true)
public class ImageContent implements Content {
// Type tag of this part; the builder uses "image_url" when no value is supplied.
@Builder.Default
private String type = "image_url";
// The image itself; written as "image_url" under the snake_case naming strategy.
private Image imageUrl;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package dev.langchain4j.model.zhipu.chat;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Builder;
import lombok.Data;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Text part of a multi-part message content.
 * <p>
 * JSON mapping, per the annotations below: snake_case property names, {@code null}
 * fields omitted when writing, unknown properties ignored when reading.
 */
@Data
@Builder
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@JsonIgnoreProperties(ignoreUnknown = true)
public final class TextContent implements Content {
// Type tag of this part; the builder uses "text" when no value is supplied.
@Builder.Default
private String type = "text";
// The text carried by this part.
private String text;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import lombok.Builder;
import lombok.Data;

import java.util.Arrays;
import java.util.List;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
import static dev.langchain4j.model.zhipu.chat.Role.USER;

Expand All @@ -19,12 +22,24 @@ public final class UserMessage implements Message {

@Builder.Default
private Role role = USER;
private String content;
private Object content;
private String name;

/**
 * Creates a user message whose content is a plain string.
 *
 * @param text the string to use as the message content
 * @return a new message built with {@code text} as its content
 */
public static UserMessage from(String text) {
    final UserMessage message = builder().content(text).build();
    return message;
}

/**
 * Creates a user message whose content is a list of content parts.
 *
 * @param contents the parts to use as the message content; the list is stored as given
 * @return a new message built with {@code contents} as its content
 */
public static UserMessage from(List<Content> contents) {
    final UserMessage message = builder().content(contents).build();
    return message;
}

/**
 * Creates a user message from content parts given as varargs.
 *
 * @param contents the parts to use as the message content, wrapped with {@link Arrays#asList}
 * @return a new message built with the wrapped parts as its content
 */
public static UserMessage from(Content... contents) {
    // Same construction as the List overload, so reuse it.
    return from(Arrays.asList(contents));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.zhipu.chat.ChatCompletionModel;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -285,4 +291,43 @@ public void onError(ChatModelErrorContext errorContext) {
assertThat(throwable).isInstanceOf(ZhipuAiException.class);
assertThat(throwable).hasMessageContaining("Authorization Token非法,请确认Authorization Token正确传递。");
}

/**
 * Sends a user message with an image and a text part to the GLM_4V model and checks
 * that the answer mentions a parrot and ends with the requested closing phrase.
 */
@Test
public void should_send_multimodal_image_data_and_receive_response() {
    // given
    ChatLanguageModel model = ZhipuAiChatModel.builder()
            .apiKey(apiKey)
            .model(ChatCompletionModel.GLM_4V)
            .build();
    List<ChatMessage> messages = multimodalChatMessagesWithImageData();

    // when
    Response<AiMessage> response = model.generate(messages);
    System.out.println(response);

    // then
    String answer = response.content().text();
    assertThat(answer)
            .containsIgnoringCase("parrot")
            .endsWith("That's all!");
}

/**
 * Builds the shared multimodal fixture: a single user message holding the parrot image
 * (as base64 data) followed by a text question about it.
 *
 * @return a one-element list containing that user message
 */
public static List<ChatMessage> multimodalChatMessagesWithImageData() {
    Image parrot = Image.builder().base64Data(multimodalImageData()).build();
    UserMessage question = UserMessage.from(
            ImageContent.from(parrot),
            TextContent.from("What animal is in the picture? When you're done, end with \"That's all!\"."));
    return Collections.singletonList(question);
}

/**
 * Reads the {@code /parrot.jpg} classpath resource and returns its bytes Base64-encoded.
 * The test fails if the resource is missing or cannot be read.
 *
 * @return the Base64 encoding of the image bytes
 */
public static String multimodalImageData() {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    try (InputStream in = ZhipuAiChatModelIT.class.getResourceAsStream("/parrot.jpg")) {
        assertThat(in).isNotNull();
        byte[] data = new byte[512];
        int n;
        while ((n = in.read(data)) != -1) {
            buffer.write(data, 0, n);
        }
    } catch (IOException e) {
        // Use fail(String, Throwable): the earlier fail("", e.getMessage()) bound to
        // fail(String, Object...), producing an empty message and losing the cause.
        Assertions.fail("Failed to read test image /parrot.jpg", e);
    }

    return Base64.getEncoder().encodeToString(buffer.toByteArray());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.zhipu.chat.ChatCompletionModel;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

Expand All @@ -23,6 +24,7 @@
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.*;
import static dev.langchain4j.model.zhipu.ZhipuAiChatModelIT.multimodalChatMessagesWithImageData;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand Down Expand Up @@ -344,4 +346,19 @@ public void onComplete(Response<AiMessage> response) {

assertThat(errorReference.get()).isInstanceOf(ZhipuAiException.class);
}

/**
 * Streaming counterpart of the multimodal check: sends the shared image-plus-text
 * message to the GLM_4V model and verifies the collected answer.
 */
@Test
public void should_send_multimodal_image_data_and_receive_response() {
    // given
    StreamingChatLanguageModel model = ZhipuAiStreamingChatModel.builder()
            .apiKey(apiKey)
            .model(ChatCompletionModel.GLM_4V)
            .build();
    TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();

    // when
    model.generate(multimodalChatMessagesWithImageData(), handler);
    Response<AiMessage> response = handler.get();
    System.out.println(response);

    // then
    String answer = response.content().text();
    assertThat(answer)
            .containsIgnoringCase("parrot")
            .endsWith("That's all!");
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit b971adf

Please sign in to comment.