@@ -23,6 +23,7 @@ import org.springframework.ai.model.tool.ToolCallingManager
23
23
import org.springframework.ai.model.tool.ToolExecutionResult
24
24
import org.springframework.util.MimeType
25
25
import org.springframework.util.MimeTypeUtils
26
+ import reactor.core.publisher.Flux
26
27
import java.util.*
27
28
28
29
class OpenAIChatModel (
@@ -35,6 +36,81 @@ class OpenAIChatModel(
35
36
private val toolExecutionEligibilityPredicate = DefaultToolExecutionEligibilityPredicate ()
36
37
37
38
/**
 * Sends [prompt] to the model and returns the resulting response.
 *
 * The prompt is first passed through [buildRequestPrompt]; the result is handed to
 * [internalCall] with no previous response.
 *
 * @param prompt the caller-supplied prompt.
 * @return the [ChatResponse] produced by [internalCall].
 */
override fun call(prompt: Prompt): ChatResponse =
    internalCall(buildRequestPrompt(prompt), null)
42
+
43
/**
 * Performs one chat-completion round trip and, when the response requires tool execution,
 * runs the tools and either returns their output directly or recurses with the updated
 * conversation history.
 *
 * @param prompt the prompt to send; its options are also consulted when deciding whether
 *   tool execution is required.
 * @param previousChatResponse the response of the preceding round, or `null` on the first
 *   round. This implementation does not read it.
 * @return the final [ChatResponse].
 */
private fun internalCall(prompt: Prompt, previousChatResponse: ChatResponse?): ChatResponse {
    val params = buildChatCompletionCreateParams(prompt)
    val completion = openAIClient.chat().completions().create(params)

    // One Generation per returned choice, each tagged with the completion id,
    // the choice index and the finish reason.
    val generations = completion.choices().map { choice ->
        val metadata: Map<String, Any> = mapOf(
            " id" to completion.id(),
            " index" to choice.index(),
            " finishReason" to choice.finishReason().value().name
        )
        buildGeneration(choice, metadata)
    }
    val response = ChatResponse.builder().generations(generations).build()

    // No tool call requested: the model's answer is final.
    if (!toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
        return response
    }

    val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
    return when {
        // The tool output is the answer; do not go back to the model.
        toolExecutionResult.returnDirect() ->
            ChatResponse.builder()
                .from(response)
                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                .build()

        // Otherwise feed the tool results back to the model for another round.
        else ->
            internalCall(
                Prompt(toolExecutionResult.conversationHistory(), prompt.options),
                response
            )
    }
}
71
+
72
/**
 * Streams the model's response to [prompt] as a [Flux] of [ChatResponse] items.
 *
 * The prompt is first passed through [buildRequestPrompt]; the result is handed to
 * [internalStream] with no previous response.
 *
 * @param prompt the caller-supplied prompt.
 * @return the [Flux] produced by [internalStream].
 */
override fun stream(prompt: Prompt): Flux<ChatResponse> =
    internalStream(buildRequestPrompt(prompt), null)
76
+
77
/**
 * Streams one chat completion and, for every emitted response that requires tool execution,
 * runs the tools and either emits their output directly or continues with a new streamed
 * round built from the updated conversation history.
 *
 * The streaming request is created lazily, once per subscription, and the underlying
 * stream response is closed when the subscription completes, fails or is cancelled.
 * Previously the request was issued as soon as this method was called (even if nobody
 * subscribed), the resulting `Flux` could only be subscribed to once, and the stream
 * response was never explicitly closed.
 *
 * @param prompt the prompt to send; its options are also consulted when deciding whether
 *   tool execution is required.
 * @param previousChatResponse the response of the preceding round, or `null` on the first
 *   round. This implementation does not read it.
 * @return a [Flux] with one [ChatResponse] per received chunk, with tool-calling chunks
 *   replaced by the tool output or by the responses of the follow-up round.
 */
private fun internalStream(prompt: Prompt, previousChatResponse: ChatResponse?): Flux<ChatResponse> {
    val chunkResponses: Flux<ChatResponse> = Flux.using(
        // Resource: opened on subscribe, not at assembly time.
        { openAIClient.chat().completions().createStreaming(buildChatCompletionCreateParams(prompt)) },
        // Source: map every streamed chunk to a ChatResponse.
        { streamResponse ->
            Flux.fromStream(streamResponse.stream()).map { chunk ->
                val generations = chunk.choices().map { choice ->
                    buildGeneration(
                        choice, mapOf(
                            " id" to chunk.id(),
                            " index" to choice.index(),
                            " finishReason" to choice.finishReason().map { reason -> reason.value().name }.orElse(" ")
                        )
                    )
                }
                ChatResponse.builder().generations(generations).build()
            }
        },
        // Cleanup: release the HTTP stream on complete, error or cancel.
        { streamResponse -> streamResponse.close() }
    )

    // NOTE(review): the eligibility check runs on each chunk individually, not on an
    // aggregated message — confirm that tool-call arguments are complete at that point.
    return chunkResponses.flatMap { response ->
        if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
            // Deferred so the tool call runs only when this inner publisher is subscribed.
            Flux.defer {
                val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
                if (toolExecutionResult.returnDirect()) {
                    // The tool output is the answer; do not go back to the model.
                    Flux.just(
                        ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build()
                    )
                } else {
                    // Feed the tool results back to the model for another streamed round.
                    this.internalStream(
                        Prompt(toolExecutionResult.conversationHistory(), prompt.options),
                        response
                    )
                }
            }
        } else {
            Flux.just(response)
        }
    }
}
112
+
113
+ private fun buildRequestPrompt (prompt : Prompt ): Prompt {
38
114
var runtimeOptions: OpenAiChatOptions ? = null
39
115
if (prompt.options != null ) {
40
116
runtimeOptions = if (prompt.options is ToolCallingChatOptions ) {
@@ -81,10 +157,10 @@ class OpenAIChatModel(
81
157
requestOptions.toolCallbacks = this .defaultOptions.toolCallbacks
82
158
requestOptions.toolContext = this .defaultOptions.toolContext
83
159
}
84
- return internalCall( prompt, null )
160
+ return prompt.mutate().chatOptions(requestOptions).build( )
85
161
}
86
162
87
- private fun internalCall (prompt : Prompt , previousChatResponse : ChatResponse ? ): ChatResponse {
163
+ private fun buildChatCompletionCreateParams (prompt : Prompt ): ChatCompletionCreateParams {
88
164
val paramsBuilder = ChatCompletionCreateParams .builder()
89
165
90
166
prompt.instructions.forEach { message ->
@@ -211,35 +287,10 @@ class OpenAIChatModel(
211
287
paramsBuilder.tools(tools)
212
288
}
213
289
}
214
-
215
- val completion = openAIClient.chat().completions().create(paramsBuilder.build())
216
- val generations = completion.choices().map { choice ->
217
- buildGeneration(
218
- choice, mapOf (
219
- " id" to completion.id(),
220
- " index" to choice.index(),
221
- " finishReason" to choice.finishReason().value().name
222
- )
223
- )
224
- }
225
- val response = ChatResponse .builder().generations(generations).build()
226
- if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
227
- val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
228
- if (toolExecutionResult.returnDirect()) {
229
- return ChatResponse .builder()
230
- .from(response)
231
- .generations(ToolExecutionResult .buildGenerations(toolExecutionResult))
232
- .build()
233
- } else {
234
- return this .internalCall(
235
- Prompt (toolExecutionResult.conversationHistory(), prompt.options),
236
- response
237
- )
238
- }
239
- }
240
- return response
290
+ return paramsBuilder.build()
241
291
}
242
292
293
+
243
294
private fun buildGeneration (
244
295
choice : ChatCompletion .Choice ,
245
296
metadata : Map <String , Any >
@@ -261,6 +312,28 @@ class OpenAIChatModel(
261
312
return Generation (assistantMessage, metadataBuilder.build())
262
313
}
263
314
315
/**
 * Converts one streamed choice into a [Generation].
 *
 * The assistant message carries the delta's text content, the supplied [metadata] and
 * the tool calls found on the delta; the generation metadata carries the finish reason.
 * Absent optional values fall back to the same placeholder strings throughout.
 *
 * NOTE(review): tool-call deltas without an id are dropped here. If the API delivers a
 * tool call's arguments in later id-less deltas, those fragments are lost — confirm.
 *
 * @param choice the streamed choice to convert.
 * @param metadata message-level metadata attached to the assistant message.
 * @return the resulting [Generation].
 */
private fun buildGeneration(
    choice: ChatCompletionChunk.Choice,
    metadata: Map<String, Any>
): Generation {
    val delta = choice.delta()

    // Only deltas that carry an id are turned into tool calls.
    val toolCalls = delta.toolCalls()
        .map { deltaCalls ->
            deltaCalls
                .filter { deltaCall -> deltaCall.id().isPresent }
                .map { deltaCall ->
                    val function = deltaCall.function()
                    AssistantMessage.ToolCall(
                        deltaCall.id().orElse(" "),
                        " function",
                        function.flatMap { it.name() }.orElse(" "),
                        function.flatMap { it.arguments() }.orElse(" ")
                    )
                }
        }
        .orElse(listOf())

    val assistantMessage = AssistantMessage(delta.content().orElse(" "), metadata, toolCalls, listOf())

    val finishReason = choice.finishReason().map { reason -> reason.value().name }.orElse(" ")
    val generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build()

    return Generation(assistantMessage, generationMetadata)
}
336
+
264
337
private fun fromAudioData (audioData : Any ): String {
265
338
return if (audioData is ByteArray ) {
266
339
Base64 .getEncoder().encodeToString(audioData)
0 commit comments