Skip to content

Commit e4cb4a8

Browse files
committed
add streaming mode support
1 parent f79c6db commit e4cb4a8

File tree

3 files changed

+117
-31
lines changed

3 files changed

+117
-31
lines changed

pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>com.javaaidev</groupId>
88
<artifactId>springai-openai-client</artifactId>
9-
<version>0.7.0</version>
9+
<version>0.8.0</version>
1010

1111
<name>OpenAI ChatModel</name>
1212
<description>Spring AI ChatModel for OpenAI using official Java SDK</description>
@@ -44,7 +44,7 @@
4444
<kotlin.code.style>official</kotlin.code.style>
4545
<kotlin.compiler.jvmTarget>${java.version}</kotlin.compiler.jvmTarget>
4646
<spring-ai.version>1.0.0</spring-ai.version>
47-
<openai-java.version>1.6.1</openai-java.version>
47+
<openai-java.version>2.5.0</openai-java.version>
4848
</properties>
4949

5050
<repositories>

src/main/kotlin/com/javaaidev/openai/OpenAIChatModel.kt

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.springframework.ai.model.tool.ToolCallingManager
2323
import org.springframework.ai.model.tool.ToolExecutionResult
2424
import org.springframework.util.MimeType
2525
import org.springframework.util.MimeTypeUtils
26+
import reactor.core.publisher.Flux
2627
import java.util.*
2728

2829
class OpenAIChatModel(
@@ -35,6 +36,81 @@ class OpenAIChatModel(
3536
private val toolExecutionEligibilityPredicate = DefaultToolExecutionEligibilityPredicate()
3637

3738
/**
 * Blocking entry point.
 *
 * Resolves the effective request prompt via [buildRequestPrompt] and hands it
 * to [internalCall] with no previous response.
 */
override fun call(prompt: Prompt): ChatResponse =
    internalCall(buildRequestPrompt(prompt), null)
42+
43+
/**
 * Issues one blocking chat-completion request for [prompt] and converts every
 * returned choice into a generation.
 *
 * When the eligibility predicate reports that tool execution is required, the
 * tool calls are executed and then either
 *  - the tool results are returned directly (`returnDirect()`), or
 *  - this function calls itself again with the conversation history produced
 *    by the tool execution and the same prompt options.
 *
 * @param prompt the fully resolved request prompt
 * @param previousChatResponse response of the preceding round; it is passed
 *        along on recursion but is not read by this function
 * @return the final chat response
 */
private fun internalCall(prompt: Prompt, previousChatResponse: ChatResponse?): ChatResponse {
    val params = buildChatCompletionCreateParams(prompt)
    val completion = openAIClient.chat().completions().create(params)

    val generations = completion.choices().map { choice ->
        val messageMetadata: Map<String, Any> = mapOf(
            "id" to completion.id(),
            "index" to choice.index(),
            "finishReason" to choice.finishReason().value().name
        )
        buildGeneration(choice, messageMetadata)
    }
    val response = ChatResponse.builder().generations(generations).build()

    // Nothing to execute: the model's answer is the final answer.
    if (!toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
        return response
    }

    val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
    return if (toolExecutionResult.returnDirect()) {
        // Hand the tool output straight back to the caller.
        ChatResponse.builder()
            .from(response)
            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
            .build()
    } else {
        // Feed the tool output back to the model for another round.
        val followUpPrompt = Prompt(toolExecutionResult.conversationHistory(), prompt.options)
        internalCall(followUpPrompt, response)
    }
}
71+
72+
/**
 * Streaming entry point.
 *
 * Resolves the effective request prompt via [buildRequestPrompt] and hands it
 * to [internalStream] with no previous response.
 */
override fun stream(prompt: Prompt): Flux<ChatResponse> =
    internalStream(buildRequestPrompt(prompt), null)
76+
77+
/**
 * Streams a chat completion for [prompt], emitting one [ChatResponse] per
 * received chunk.
 *
 * The streaming HTTP response is opened lazily, once per subscriber, and is
 * closed when the Flux completes, errors or is cancelled. Previously the
 * request was issued as soon as this function was called and the streaming
 * response object was never closed.
 *
 * For every emitted response the eligibility predicate is consulted; when tool
 * execution is required the tool calls are executed and then either
 *  - the tool results are emitted directly (`returnDirect()`), or
 *  - a new stream is started with the conversation history produced by the
 *    tool execution and the same prompt options.
 *
 * NOTE(review): tool execution is decided per chunk. If the API splits a tool
 * call's arguments across several chunks, they would have to be aggregated
 * before execution — confirm against the streaming API behaviour.
 *
 * @param prompt the fully resolved request prompt
 * @param previousChatResponse response of the preceding round; it is passed
 *        along on recursion but is not read by this function
 * @return a Flux of chat responses, one per chunk (plus any follow-up rounds)
 */
private fun internalStream(prompt: Prompt, previousChatResponse: ChatResponse?): Flux<ChatResponse> {
    val chunkResponses: Flux<ChatResponse> = Flux.using(
        // Resource: opened on subscription, not at assembly time.
        { openAIClient.chat().completions().createStreaming(buildChatCompletionCreateParams(prompt)) },
        // Source: map each streamed chunk to a ChatResponse.
        { streamResponse ->
            Flux.fromStream(streamResponse.stream()).map { chunk ->
                val generations = chunk.choices().map { choice ->
                    buildGeneration(
                        choice, mapOf(
                            "id" to chunk.id(),
                            "index" to choice.index(),
                            "finishReason" to choice.finishReason().map { reason -> reason.value().name }.orElse("")
                        )
                    )
                }
                ChatResponse.builder().generations(generations).build()
            }
        },
        // Cleanup: always release the underlying streaming response.
        { streamResponse -> streamResponse.close() }
    )

    return chunkResponses.flatMap { response ->
        if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
            // Defer so the tool execution happens on subscription of the inner publisher.
            Flux.defer {
                val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
                if (toolExecutionResult.returnDirect()) {
                    Flux.just(
                        ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build()
                    )
                } else {
                    this.internalStream(
                        Prompt(toolExecutionResult.conversationHistory(), prompt.options),
                        response
                    )
                }
            }
        } else {
            Flux.just(response)
        }
    }
}
112+
113+
private fun buildRequestPrompt(prompt: Prompt): Prompt {
38114
var runtimeOptions: OpenAiChatOptions? = null
39115
if (prompt.options != null) {
40116
runtimeOptions = if (prompt.options is ToolCallingChatOptions) {
@@ -81,10 +157,10 @@ class OpenAIChatModel(
81157
requestOptions.toolCallbacks = this.defaultOptions.toolCallbacks
82158
requestOptions.toolContext = this.defaultOptions.toolContext
83159
}
84-
return internalCall(prompt, null)
160+
return prompt.mutate().chatOptions(requestOptions).build()
85161
}
86162

87-
private fun internalCall(prompt: Prompt, previousChatResponse: ChatResponse?): ChatResponse {
163+
private fun buildChatCompletionCreateParams(prompt: Prompt): ChatCompletionCreateParams {
88164
val paramsBuilder = ChatCompletionCreateParams.builder()
89165

90166
prompt.instructions.forEach { message ->
@@ -211,35 +287,10 @@ class OpenAIChatModel(
211287
paramsBuilder.tools(tools)
212288
}
213289
}
214-
215-
val completion = openAIClient.chat().completions().create(paramsBuilder.build())
216-
val generations = completion.choices().map { choice ->
217-
buildGeneration(
218-
choice, mapOf(
219-
"id" to completion.id(),
220-
"index" to choice.index(),
221-
"finishReason" to choice.finishReason().value().name
222-
)
223-
)
224-
}
225-
val response = ChatResponse.builder().generations(generations).build()
226-
if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
227-
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
228-
if (toolExecutionResult.returnDirect()) {
229-
return ChatResponse.builder()
230-
.from(response)
231-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
232-
.build()
233-
} else {
234-
return this.internalCall(
235-
Prompt(toolExecutionResult.conversationHistory(), prompt.options),
236-
response
237-
)
238-
}
239-
}
240-
return response
290+
return paramsBuilder.build()
241291
}
242292

293+
243294
private fun buildGeneration(
244295
choice: ChatCompletion.Choice,
245296
metadata: Map<String, Any>
@@ -261,6 +312,28 @@ class OpenAIChatModel(
261312
return Generation(assistantMessage, metadataBuilder.build())
262313
}
263314

315+
/**
 * Converts one streamed chunk choice into a [Generation].
 *
 * Tool-call deltas that carry no id are skipped; for the remaining ones a
 * missing function name or arguments value becomes an empty string. A missing
 * delta content or finish reason likewise becomes an empty string.
 *
 * @param choice the chunk choice to convert
 * @param metadata metadata attached to the resulting assistant message
 * @return the generation holding the assistant message and its finish reason
 */
private fun buildGeneration(
    choice: ChatCompletionChunk.Choice,
    metadata: Map<String, Any>
): Generation {
    val delta = choice.delta()

    val toolCalls = delta.toolCalls()
        .map { deltaCalls ->
            deltaCalls
                .filter { deltaCall -> deltaCall.id().isPresent }
                .map { deltaCall ->
                    val function = deltaCall.function()
                    AssistantMessage.ToolCall(
                        deltaCall.id().orElse(""),
                        "function",
                        function.flatMap { it.name() }.orElse(""),
                        function.flatMap { it.arguments() }.orElse("")
                    )
                }
        }
        .orElse(listOf())

    val finishReason = choice.finishReason().map { reason -> reason.value().name }.orElse("")
    val generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build()

    val text = delta.content().orElse("")
    val assistantMessage = AssistantMessage(text, metadata, toolCalls, listOf())
    return Generation(assistantMessage, generationMetadata)
}
336+
264337
private fun fromAudioData(audioData: Any): String {
265338
return if (audioData is ByteArray) {
266339
Base64.getEncoder().encodeToString(audioData)

src/test/kotlin/com/javaaidev/openai/OpenAIChatModelTest.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import org.springframework.ai.tool.function.FunctionToolCallback
1010
import org.springframework.ai.tool.resolution.ToolCallbackResolver
1111
import java.util.function.Function
1212
import kotlin.test.assertNotNull
13+
import kotlin.test.assertTrue
1314

1415
class OpenAIChatModelTest {
1516
private val chatClient: ChatClient
@@ -35,6 +36,18 @@ class OpenAIChatModelTest {
3536
assertNotNull(response)
3637
}
3738

39+
/**
 * Streams a simple completion and checks that the concatenated chunk texts
 * are not empty.
 */
@Test
@DisplayName("Simple streaming completion")
fun testStreamCompletion() {
    val builder = StringBuilder()
    chatClient.prompt().user("tell me a joke")
        .stream().chatResponse()
        .doOnNext { response ->
            // Skip chunks that have no generation or no text, instead of
            // failing on a null result or appending the literal "null".
            response.result?.output?.text?.let { text -> builder.append(text) }
        }
        .blockLast()
    val result = builder.toString()
    assertTrue(result.isNotEmpty(), "streamed completion should contain text")
}
50+
3851
@Test
3952
@DisplayName("Tool calling")
4053
fun testToolCalling() {

0 commit comments

Comments
 (0)