Commit 4ece8c2

Merge PR TheoKanning#402: Added response_format capabilities to chat completion request, by @PrimosK
1 parent 59bbeb3 · commit 4ece8c2

File tree

2 files changed: +110 additions, -40 deletions


api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java

Lines changed: 64 additions & 39 deletions

@@ -15,11 +15,6 @@
 @NoArgsConstructor
 public class ChatCompletionRequest {

-    /**
-     * ID of the model to use.
-     */
-    String model;
-
     /**
      * The messages to generate chat completions for, in the <a
      * href="https://platform.openai.com/docs/guides/chat/introduction">chat format</a>.<br>
@@ -28,36 +23,38 @@ public class ChatCompletionRequest {
     List<ChatMessage> messages;

     /**
-     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
-     * values like 0.2 will make it more focused and deterministic.<br>
-     * We generally recommend altering this or top_p but not both.
-     */
-    Double temperature;
-
-    /**
-     * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
-     * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>
-     * We generally recommend altering this or temperature but not both.
+     * ID of the model to use.
      */
-    @JsonProperty("top_p")
-    Double topP;
+    String model;

     /**
-     * How many chat completion chatCompletionChoices to generate for each input message.
+     * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+     * decreasing the model's likelihood to repeat the same line verbatim.
      */
-    Integer n;
+    @JsonProperty("frequency_penalty")
+    Double frequencyPenalty;

     /**
-     * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only <a
-     * href="https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format">server-sent
-     * events</a> as they become available, with the stream terminated by a data: [DONE] message.
+     * <p>An object specifying the format that the model must output.</p>
+     *
+     * <p>Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.</p>
+     *
+     * <p><b>Important:</b> when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
+     * Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting
+     * in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if
+     * finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.</p>
      */
-    Boolean stream;
+    @JsonProperty("response_format")
+    ResponseFormat responseFormat;

     /**
-     * Up to 4 sequences where the API will stop generating further tokens.
+     * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
+     * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
+     * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
+     * should result in a ban or exclusive selection of the relevant token.
      */
-    List<String> stop;
+    @JsonProperty("logit_bias")
+    Map<String, Integer> logitBias;

     /**
      * The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will
@@ -66,6 +63,11 @@
     @JsonProperty("max_tokens")
     Integer maxTokens;

+    /**
+     * How many chat completion chatCompletionChoices to generate for each input message.
+     */
+    Integer n;
+
     /**
      * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
      * increasing the model's likelihood to talk about new topics.
@@ -74,38 +76,48 @@ public class ChatCompletionRequest {
     Double presencePenalty;

     /**
-     * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
-     * decreasing the model's likelihood to repeat the same line verbatim.
+     * Up to 4 sequences where the API will stop generating further tokens.
      */
-    @JsonProperty("frequency_penalty")
-    Double frequencyPenalty;
+    List<String> stop;

     /**
-     * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
-     * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
-     * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
-     * should result in a ban or exclusive selection of the relevant token.
+     * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only <a
+     * href="https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format">server-sent
+     * events</a> as they become available, with the stream terminated by a data: [DONE] message.
      */
-    @JsonProperty("logit_bias")
-    Map<String, Integer> logitBias;
+    Boolean stream;

+    /**
+     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
+     * values like 0.2 will make it more focused and deterministic.<br>
+     * We generally recommend altering this or top_p but not both.
+     */
+    Double temperature;

     /**
-     * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
+     * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
+     * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>
+     * We generally recommend altering this or temperature but not both.
      */
-    String user;
+    @JsonProperty("top_p")
+    Double topP;

     /**
-     * A list of the available functions.
+     * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
      */
-    List<?> functions;
+    String user;

     /**
      * Controls how the model responds to function calls, as specified in the <a href="https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call">OpenAI documentation</a>.
      */
     @JsonProperty("function_call")
     ChatCompletionRequestFunctionCall functionCall;

+    /**
+     * A list of the available functions.
+     */
+    List<?> functions;
+
     @Data
     @Builder
     @AllArgsConstructor
@@ -119,6 +131,19 @@ public static ChatCompletionRequestFunctionCall of(String name) {

     }

+    @Data
+    @Builder
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class ResponseFormat {
+        String type;
+
+        public static ResponseFormat of(String type) {
+            return new ResponseFormat(type);
+        }
+
+    }
+
     /**
      * A list of tools the model may call. Currently, only functions are supported as a tool.
      */
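
With the new nested class and the Lombok-generated builder, callers opt into JSON mode via ResponseFormat.of("json_object"). Below is a minimal sketch of how the new field lands on the wire; it is not part of this commit, the demo class name is hypothetical, and the standalone ObjectMapper setup is an assumption (the service normally configures its own mapper):

import java.util.Collections;
import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;

// Hypothetical demo class, not part of the library or this commit.
public class ResponseFormatWireSketch {
    public static void main(String[] args) throws Exception {
        // Per the Javadoc above, JSON mode still requires telling the model to emit JSON.
        List<ChatMessage> messages = Collections.singletonList(
                new ChatMessage(ChatMessageRole.USER.value(), "Describe the weather as JSON."));

        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model("gpt-4-1106-preview")
                .messages(messages)
                .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
                .build();

        ObjectMapper mapper = new ObjectMapper();
        // Assumption: skip the many unset optional fields so only what was set is serialized.
        mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);

        // Thanks to @JsonProperty("response_format"), this should print something like:
        // {"messages":[{"role":"user","content":"Describe the weather as JSON."}],
        //  "model":"gpt-4-1106-preview","response_format":{"type":"json_object"}}
        System.out.println(mapper.writeValueAsString(request));
    }
}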

service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java

Lines changed: 46 additions & 1 deletion

@@ -2,11 +2,14 @@

 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonParser;
 import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.theokanning.openai.completion.chat.*;
 import org.junit.jupiter.api.Test;

+import java.io.IOException;
 import java.util.*;

 import static org.junit.jupiter.api.Assertions.*;
@@ -23,7 +26,7 @@ static class Weather {
     }

     enum WeatherUnit {
-        CELSIUS, FAHRENHEIT;
+        CELSIUS, FAHRENHEIT
     }

     static class WeatherResponse {
@@ -375,4 +378,46 @@ void createChatCompletionWithToolFunctions() {
         assertNotNull(choice2.getMessage().getContent());
     }

+    @Test
+    void streamChatCompletionWithJsonResponseFormat() {
+        final List<ChatMessage> messages = new ArrayList<>();
+
+        // The system message is deliberately vague so that it does not dictate how the response should look.
+        // The main point is that the chat completion should always contain JSON content.
+        final ChatMessage systemMessage = new ChatMessage(
+                ChatMessageRole.SYSTEM.value(),
+                "You are a dog and will speak as such - but please do it in JSON."
+        );
+
+        messages.add(systemMessage);
+
+        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
+                .builder()
+                .model("gpt-4-1106-preview")
+                .messages(messages)
+                .n(1)
+                .maxTokens(256)
+                .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
+                .build();
+
+        ChatCompletionResult chatCompletion = service.createChatCompletion(chatCompletionRequest);
+
+        ChatCompletionChoice chatCompletionChoice = chatCompletion.getChoices().get(0);
+        String expectedJsonContent = chatCompletionChoice.getMessage().getContent();
+
+        assertTrue(isValidJSON(expectedJsonContent), "Invalid JSON response:\n\n" + expectedJsonContent);
+    }
+
+    private boolean isValidJSON(String json) {
+        try (final JsonParser parser = new ObjectMapper().createParser(json)) {
+            while (parser.nextToken() != null) {
+                // Just read all tokens in order to verify that this is valid JSON.
+            }
+            return true;
+        } catch (IOException ioe) {
+            ioe.printStackTrace();
+            return false;
+        }
+    }
+
 }
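
The isValidJSON helper walks every token with Jackson's streaming JsonParser, which validates syntax without materializing a tree and also trips over malformed trailing content after the first JSON value. A hypothetical tree-based alternative (not in this commit; the class and method names are illustrative) would be shorter but, by default, does not check trailing tokens:

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

// Hypothetical alternative sketch, not part of this commit.
final class JsonValidationSketch {
    static boolean isValidJsonViaTree(String json) {
        try {
            // readTree parses the first JSON value; unlike the streaming loop above,
            // it ignores trailing content unless FAIL_ON_TRAILING_TOKENS is enabled.
            new ObjectMapper().readTree(json);
            return true;
        } catch (JsonProcessingException e) {
            return false;
        }
    }
}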
