Skip to content

Commit a502d83

Browse files
committed
Merge PR TheoKanning#413 Add support for gpt-4-vision by @JanCong
with a few modifications that remove the parameterized `ChatMessage`
1 parent 4ece8c2 commit a502d83

File tree

14 files changed

+367
-7
lines changed

14 files changed

+367
-7
lines changed

api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public class ChatCompletionChoice {
1515
Integer index;
1616

1717
/**
18-
* The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated
18+
* The {@link ChatMessageRole#ASSISTANT} message or delta (when streaming) which was generated
1919
*/
2020
@JsonAlias("delta")
2121
ChatMessage message;
@@ -25,4 +25,10 @@ public class ChatCompletionChoice {
2525
*/
2626
@JsonProperty("finish_reason")
2727
String finishReason;
28+
29+
/**
30+
* When the GPT-4V model is used, this field is returned, for example {"type":"stop","stop":"<|fim_suffix|>"}.
31+
*/
32+
@JsonProperty("finish_details")
33+
FinishDetails finishDetails;
2834
}

api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package com.theokanning.openai.completion.chat;
22

3+
import com.fasterxml.jackson.annotation.JsonIgnore;
34
import com.fasterxml.jackson.annotation.JsonInclude;
45
import com.fasterxml.jackson.annotation.JsonProperty;
5-
import lombok.*;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import lombok.AllArgsConstructor;
8+
import lombok.Data;
9+
import lombok.NoArgsConstructor;
610

711
import java.util.List;
812

@@ -28,7 +32,7 @@ public class ChatMessage {
2832
*/
2933
String role;
3034
@JsonInclude() // content should always exist in the call, even if it is null
31-
String content;
35+
Object content;
3236
// name is optional: the name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
3337
String name;
3438

@@ -38,6 +42,17 @@ public class ChatMessage {
3842
@JsonProperty("function_call")
3943
ChatFunctionCall functionCall;
4044

45+
public ChatMessage(String role, List<ChatMessageContent> content) {
46+
this.role = role == null ? "assistant" : role;
47+
this.content = content;
48+
}
49+
50+
public ChatMessage(String role, List<ChatMessageContent> content, String name) {
51+
this.role = role == null ? "assistant" : role;
52+
this.content = content;
53+
this.name = name;
54+
}
55+
4156
public ChatMessage(String role, String content) {
4257
this.role = role == null ? "assistant" : role;
4358
this.content = content;
@@ -49,4 +64,9 @@ public ChatMessage(String role, String content, String name) {
4964
this.name = name;
5065
}
5166

67+
@JsonIgnore
68+
public String getStringContent() {
69+
return content instanceof String ? (String) content : null;
70+
}
71+
5272
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import lombok.Data;
5+
import lombok.NoArgsConstructor;
6+
7+
@Data
8+
@NoArgsConstructor
9+
public class ChatMessageContent {
10+
11+
/**
12+
* The type of the content part
13+
*
14+
* @see ChatMessageContentType
15+
*/
16+
private String type;
17+
18+
/**
19+
* The text content.
20+
*/
21+
private String text;
22+
23+
/**
24+
* Image input is only supported when using the gpt-4-vision-preview model.
25+
*/
26+
@JsonProperty("image_url")
27+
private ImageUrl imageUrl;
28+
29+
public ChatMessageContent(String text) {
30+
this.type = ChatMessageContentType.TEXT.value();
31+
this.text = text;
32+
}
33+
34+
public ChatMessageContent(ImageUrl imageUrl) {
35+
this.type = ChatMessageContentType.IMAGE_URL.value();
36+
this.imageUrl = imageUrl;
37+
}
38+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
/**
4+
* see {@link ChatMessage} documentation.
5+
*/
6+
public enum ChatMessageContentType {
7+
8+
TEXT("text"),
9+
IMAGE_URL("image_url");
10+
11+
private final String value;
12+
13+
ChatMessageContentType(final String value) {
14+
this.value = value;
15+
}
16+
17+
public String value() {
18+
return value;
19+
}
20+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
public interface Content {
4+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.Data;
4+
5+
/**
6+
* Details about why generation finished, mapped from the {@code finish_details} response field.
7+
*
8+
* @author cong
9+
* @since 2023/12/3
10+
*/
11+
@Data
12+
public class FinishDetails {
13+
14+
/**
15+
* The reason why GPT stopped generating, for example "stop", "max_tokens".
16+
*/
17+
private String type;
18+
19+
/**
20+
* For example "<|fim_suffix|>"
21+
*/
22+
private String stop;
23+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.*;
4+
5+
@Data
6+
@AllArgsConstructor
7+
@NoArgsConstructor
8+
@RequiredArgsConstructor
9+
public class ImageUrl {
10+
11+
/**
12+
* Either a URL of the image or the base64 encoded image data.
13+
*/
14+
@NonNull
15+
private String url;
16+
17+
/**
18+
* Specifies the detail level of the image. Learn more in the
19+
* <a href="https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding">
20+
* Vision guide</a>.
21+
*/
22+
private String detail;
23+
}

api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
186186
int sum = 0;
187187
for (ChatMessage msg : messages) {
188188
sum += tokensPerMessage;
189-
sum += tokens(encoding, msg.getContent());
189+
if (msg.getContent() instanceof String) sum += tokens(encoding, (String) msg.getContent());
190190
sum += tokens(encoding, msg.getRole());
191191
sum += tokens(encoding, msg.getName());
192192
if (isNotBlank(msg.getName())) {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.theokanning.openai.utils;
2+
3+
import com.theokanning.openai.completion.chat.ChatMessage;
4+
import com.theokanning.openai.completion.chat.ChatMessageContent;
5+
import com.theokanning.openai.completion.chat.ImageUrl;
6+
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.regex.Matcher;
10+
import java.util.regex.Pattern;
11+
12+
/**
13+
* Utility class that converts plain-text chat messages into the multi-part (text and image URL) content format.
14+
*
15+
* @author cong
16+
* @since 2023/11/17
17+
*/
18+
public class VisionUtil {
19+
20+
private static final Pattern pattern = Pattern.compile("(https?://\\S+)");
21+
22+
public static ChatMessage convertForVision(ChatMessage msg) {
23+
List<ChatMessageContent> content = new ArrayList<>();
24+
String sourceText = msg.getStringContent();
25+
// Regular expression to match image URLs
26+
Matcher matcher = pattern.matcher(sourceText);
27+
// Find image URLs and split the string
28+
int lastIndex = 0;
29+
while (matcher.find()) {
30+
// Add the text before the image URL
31+
if (matcher.start() > lastIndex) {
32+
content.add(new ChatMessageContent(sourceText.substring(lastIndex, matcher.start()).trim()));
33+
}
34+
// Add the image URL
35+
ImageUrl imageUrl = new ImageUrl();
36+
imageUrl.setUrl(matcher.group());
37+
content.add(new ChatMessageContent(imageUrl));
38+
lastIndex = matcher.end();
39+
}
40+
// Add the remaining text
41+
if (lastIndex < sourceText.length()) {
42+
content.add(new ChatMessageContent(sourceText.substring(lastIndex).trim()));
43+
}
44+
return new ChatMessage(msg.getRole(), content, msg.getName());
45+
}
46+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package example;
2+
3+
import com.theokanning.openai.completion.chat.*;
4+
import com.theokanning.openai.service.OpenAiService;
5+
import com.theokanning.openai.utils.VisionUtil;
6+
7+
import java.time.Duration;
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
class OpenAiApiVisionExample {
12+
public static void main(String... args) {
13+
String token = System.getenv("OPENAI_TOKEN");
14+
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(30));
15+
16+
System.out.println("Streaming chat completion...");
17+
final List<ChatMessage> messages = new ArrayList<>();
18+
List<ChatMessageContent> content = new ArrayList<>();
19+
content.add(new ChatMessageContent("What’s in this image?"));
20+
content.add(new ChatMessageContent(new ImageUrl(
21+
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")));
22+
messages.add(new ChatMessage(ChatMessageRole.USER.value(), content));
23+
24+
// use VisionUtil to convert image prompt to OpenAI format
25+
System.out.println("Converting image to OpenAI format...");
26+
ChatMessage visionChatMessage = VisionUtil.convertForVision(
27+
new ChatMessage(ChatMessageRole.USER.value(),
28+
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg "
29+
+ "What are in these images? Is there any difference between them?"));
30+
messages.add(visionChatMessage);
31+
32+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
33+
.builder()
34+
.model("gpt-4-vision-preview")
35+
.messages(messages)
36+
.maxTokens(300)
37+
.build();
38+
39+
service.streamChatCompletion(chatCompletionRequest).blockingForEach(System.out::println);
40+
service.shutdownExecutor();
41+
}
42+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
5+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
6+
7+
/**
8+
* @author cong
9+
* @since 2023/11/17
10+
*/
11+
public abstract class ChatMessageMixIn {
12+
@JsonProperty("content")
13+
@JsonSerialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentSerializer.class)
14+
@JsonDeserialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentDeserializer.class)
15+
abstract Object getContent();
16+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.fasterxml.jackson.core.JsonGenerator;
4+
import com.fasterxml.jackson.core.JsonParser;
5+
import com.fasterxml.jackson.databind.*;
6+
import com.theokanning.openai.completion.chat.*;
7+
8+
import java.io.IOException;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Optional;
12+
13+
public class ChatMessageSerializerAndDeserializer {
14+
15+
public static class ChatMessageContentSerializer extends JsonSerializer<Object> {
16+
@Override
17+
public void serialize(Object content, JsonGenerator gen, SerializerProvider serializers) throws IOException {
18+
if (content == null) {
19+
gen.writeNull();
20+
return;
21+
}
22+
if (content instanceof String) {
23+
gen.writeString((String) content);
24+
return;
25+
}
26+
if (content instanceof List) {
27+
gen.writeStartArray();
28+
List<?> contentList = (List<?>)content;
29+
for (Object item : contentList) {
30+
if (item instanceof ChatMessageContent) {
31+
ChatMessageContent contentItem = (ChatMessageContent)item;
32+
gen.writeStartObject();
33+
gen.writeStringField("type", contentItem.getType());
34+
if (ChatMessageContentType.TEXT.value().equals(contentItem.getType())) {
35+
gen.writeStringField("text", contentItem.getText());
36+
} else if (ChatMessageContentType.IMAGE_URL.value().equals(contentItem.getType())) {
37+
gen.writeObjectFieldStart("image_url");
38+
gen.writeStringField("url", contentItem.getImageUrl().getUrl());
39+
gen.writeStringField("detail", contentItem.getImageUrl().getDetail());
40+
gen.writeEndObject();
41+
}
42+
gen.writeEndObject();
43+
}
44+
}
45+
gen.writeEndArray();
46+
}
47+
}
48+
}
49+
50+
public static class ChatMessageContentDeserializer extends JsonDeserializer<Object> {
51+
@Override
52+
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
53+
JsonNode contentNode = p.readValueAsTree();
54+
if (contentNode.isTextual()) {
55+
return contentNode.asText();
56+
}
57+
if (contentNode.isArray()) {
58+
List<Object> contentList = new ArrayList<>();
59+
for (JsonNode itemNode : contentNode) {
60+
String type = itemNode.get("type").asText();
61+
if (ChatMessageContentType.TEXT.value().equals(type)) {
62+
contentList.add(new ChatMessageContent(itemNode.get("text").asText()));
63+
} else if (ChatMessageContentType.IMAGE_URL.value().equals(type)) {
64+
JsonNode imageUrlJsonNode = itemNode.get("image_url");
65+
ImageUrl imageUrl = new ImageUrl();
66+
imageUrl.setUrl(Optional.ofNullable(imageUrlJsonNode.get("url"))
67+
.map(JsonNode::asText).orElse(null));
68+
imageUrl.setDetail(Optional.ofNullable(imageUrlJsonNode.get("detail"))
69+
.map(JsonNode::asText).orElse(null));
70+
contentList.add(new ChatMessageContent(imageUrl));
71+
}
72+
}
73+
return contentList;
74+
}
75+
return null;
76+
}
77+
}
78+
79+
}

0 commit comments

Comments
 (0)